In [1]:
import os
import time
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import pandas as pd
import psutil
from datasets import load_dataset
from transformers import (
    BertTokenizer,
    AutoTokenizer,
    BertForSequenceClassification,
    BertConfig,
    TrainingArguments,
    Trainer,
    PreTrainedModel,
    PretrainedConfig,
    set_seed
)
from transformers.trainer_utils import get_last_checkpoint
from transformers import TrainingArguments, Trainer, TrainerCallback
from mamba_ssm.models.mixer_seq_simple import MixerModel

# 设置随机种子以确保可复现性
set_seed(202504)

# 确定设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# 创建结果目录
os.makedirs("results", exist_ok=True)
os.makedirs("results/mamba", exist_ok=True) 
os.makedirs("results/transformer", exist_ok=True)
os.makedirs("results/figures", exist_ok=True)

2025-04-15 15:01:02.148833: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744729262.168810  179544 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744729262.174853  179544 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1744729262.190408  179544 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1744729262.190432  179544 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1744729262.190435  179544 computation_placer.cc:177] computation placer alr

Using device: cuda


In [2]:
# 创建Mamba配置类
class MambaConfig(PretrainedConfig):
    model_type = "mamba"
    
    def __init__(
        self,
        vocab_size=50277,
        hidden_size=768,
        intermediate_size=3072,  # 添加中间层大小参数
        state_size=16,
        num_hidden_layers=12,
        num_classes=2,
        pad_token_id=0,
        max_position_embeddings=2048,
        **kwargs
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size  # 保存中间层大小
        self.state_size = state_size
        self.num_hidden_layers = num_hidden_layers
        self.num_classes = num_classes
        self.max_position_embeddings = max_position_embeddings
        super().__init__(pad_token_id=pad_token_id, **kwargs)

# 创建基于Mamba的序列分类模型
class MambaForSequenceClassification(PreTrainedModel):
    config_class = MambaConfig
    
    def __init__(self, config):
        super().__init__(config)
        
        # 初始化Mamba模型 - 添加d_intermediate参数
        self.mamba = MixerModel(
            d_model=config.hidden_size,
            n_layer=config.num_hidden_layers,
            d_intermediate=config.intermediate_size,  # 添加d_intermediate
            vocab_size=config.vocab_size,
            ssm_cfg={"d_state": config.state_size}
        )
        
        # 添加分类头
        self.classifier = nn.Linear(config.hidden_size, config.num_classes)
        
        # 初始化权重
        self._init_weights()
        
    def _init_weights(self):
        # 初始化分类头
        if self.classifier is not None:
            self.classifier.weight.data.normal_(mean=0.0, std=0.02)
            if self.classifier.bias is not None:
                self.classifier.bias.data.zero_()
    
    def forward(self, input_ids, attention_mask=None, labels=None):
        # 确保输入类型正确
        if not isinstance(input_ids, torch.Tensor):
            input_ids = torch.tensor(input_ids, dtype=torch.long, device=self.device)
            
        # 通过Mamba模型获取序列输出
        # 注意：根据Mamba的实际API调整参数
        outputs = self.mamba(input_ids)
        
        # 使用最后一个token的输出作为整个序列的表示
        sequence_output = outputs[:, -1]
        
        # 通过分类器获取分类结果
        logits = self.classifier(sequence_output)
        
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_classes), labels.view(-1))
        
        return {"loss": loss, "logits": logits} if loss is not None else {"logits": logits}

In [3]:
# 修改评估指标计算函数，以适应不同的返回值结构
def compute_metrics_with_efficiency(eval_pred, model_type, start_time=None):
    # 检查eval_pred的类型和结构
    if hasattr(eval_pred, 'predictions') and hasattr(eval_pred, 'label_ids'):
        # 如果是PredictionOutput类型对象
        predictions = eval_pred.predictions
        labels = eval_pred.label_ids
    elif isinstance(eval_pred, tuple):
        # 如果是元组，按位置解包
        if len(eval_pred) >= 2:
            predictions = eval_pred[0]
            labels = eval_pred[1]
        else:
            raise ValueError(f"Expected at least 2 elements in eval_pred tuple, got {len(eval_pred)}")
    else:
        raise ValueError(f"Unexpected eval_pred type: {type(eval_pred)}")
    
    # 对predictions进行必要的后处理
    if len(predictions.shape) > 1 and predictions.shape[1] > 1:
        # 如果是logits，转换为类别预测
        predictions = np.argmax(predictions, axis=1)
    
    # 计算准确率
    accuracy = np.mean(predictions == labels)
    
    metrics = {"accuracy": accuracy}
    
    # 如果是完整评估（传入start_time），则计算训练时间
    if start_time is not None:
        training_time = time.time() - start_time
        metrics["training_time"] = training_time
    
    # 记录当前内存使用情况
    memory_usage = psutil.Process().memory_info().rss / 1024 ** 2  # MB
    metrics["memory_usage"] = memory_usage
    
    return metrics

# 推理速度测试函数
def test_inference_speed(model, tokenizer, sentences, device, max_length=128):
    # 将输入移至设备
    inputs = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    # 根据模型类型移除不需要的参数
    if "MambaForSequenceClassification" in model.__class__.__name__:
        # Mamba模型不使用token_type_ids
        if 'token_type_ids' in inputs:
            del inputs['token_type_ids']
    
    # 预热
    for _ in range(5):
        _ = model(**inputs)
    
    # 测速
    start_time = time.time()
    for _ in range(100):  # 多次测试取平均
        _ = model(**inputs)
    end_time = time.time()
    
    avg_time = (end_time - start_time) / 100
    return avg_time

# 改进的长序列处理能力测试函数
def test_long_sequence_performance(model_name, tokenizer, model_class, config, device, lengths=[128, 256, 512, 1024, 2048, 4096, 8192, 16384]):
    results = {}
    
    for length in lengths:
        print(f"Testing sequence length: {length}")
        
        # 创建一个新的适配这个长度的模型
        if "mamba" in model_name.lower():
            # Mamba模型配置
            test_config = MambaConfig(**config.to_dict())
            test_config.max_position_embeddings = length
            model = model_class(test_config).to(device)
        else:  # transformer
            # 检查是否是预训练配置还是自定义配置
            if hasattr(config, 'to_dict'):
                test_config = BertConfig(**config.to_dict())
                test_config.max_position_embeddings = length
                model = model_class(test_config).to(device)
            else:
                # 如果是预训练模型配置，直接使用模型类创建
                test_config = BertConfig(
                    vocab_size=config.vocab_size,
                    hidden_size=config.hidden_size,
                    num_hidden_layers=config.num_hidden_layers,
                    num_attention_heads=config.num_attention_heads,
                    intermediate_size=config.intermediate_size,
                    max_position_embeddings=length,  # 调整最大序列长度
                    num_labels=config.num_labels
                )
                model = model_class(test_config).to(device)
        
        # 生成一个随机的测试序列
        random_tokens = torch.randint(100, 5000, (1, length)).to(device)
        attention_mask = torch.ones_like(random_tokens)
        
        # 预热
        for _ in range(3):
            if "mamba" in model_name.lower():
                _ = model(random_tokens, attention_mask=attention_mask)
            else:
                _ = model(random_tokens, attention_mask=attention_mask)
        
        # 测速
        torch.cuda.synchronize() if torch.cuda.is_available() else None
        start_time = time.time()
        for _ in range(10):
            if "mamba" in model_name.lower():
                _ = model(random_tokens, attention_mask=attention_mask)
            else:
                _ = model(random_tokens, attention_mask=attention_mask)
        torch.cuda.synchronize() if torch.cuda.is_available() else None
        end_time = time.time()
        
        # 记录结果
        avg_time = (end_time - start_time) / 10
        if torch.cuda.is_available():
            memory_usage = torch.cuda.max_memory_allocated() / 1024**2  # MB
        else:
            memory_usage = psutil.Process().memory_info().rss / 1024**2  # MB
        
        results[length] = {"time": avg_time, "memory": memory_usage}
        
        # 清理内存
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    
    return results

# 统一的数据处理函数
def process_sst2_data(tokenizer, max_length=128, test_mode=False):
    # 加载SST-2数据集
    datasets = load_dataset("glue", "sst2")
    
    # 测试模式下只使用一小部分数据
    if test_mode:
        # 只使用前100个样本
        test_size = 100
        datasets["train"] = datasets["train"].select(range(test_size))
        datasets["validation"] = datasets["validation"].select(range(test_size))
        print(f"TEST MODE: Using only {test_size} samples for training and validation")
    
    # 定义预处理函数
    def preprocess_function(examples):
        return tokenizer(
            examples["sentence"],
            truncation=True,
            padding="max_length",
            max_length=max_length
        )
    
    # 应用预处理
    tokenized_datasets = datasets.map(preprocess_function, batched=True)
    
    # 设置数据格式
    tokenized_datasets = tokenized_datasets.remove_columns(["sentence", "idx"])
    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
    tokenized_datasets.set_format("torch")
    
    return tokenized_datasets

In [4]:
# 改进的Mamba模型训练与评估函数
def train_evaluate_mamba(test_mode=False):
    print(f"\n{'='*50}")
    print(f"Running Mamba experiment {'(TEST MODE)' if test_mode else ''} (FROM SCRATCH)")
    print(f"{'='*50}")
    
    # Mamba实验配置 - 参数量与Transformer匹配
    if test_mode:
        mamba_config = MambaConfig(
            vocab_size=30522,
            hidden_size=128,
            intermediate_size=512,
            num_hidden_layers=2,
            state_size=8,
            pad_token_id=0,
            num_classes=2,
        )
    else:
        # 调整后的配置，参数量约5400万
        mamba_config = MambaConfig(
            vocab_size=30522,
            hidden_size=512,
            intermediate_size=2048,
            num_hidden_layers=8,
            state_size=16,
            pad_token_id=0,
            num_classes=2,
        )
    
    # 使用与Transformer相同的tokenizer
    try:
        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", local_files_only=True)
        print("使用本地缓存的tokenizer")
    except:
        print("创建基本tokenizer...")
        from transformers import BertTokenizer
        vocab_file = os.path.join("vocab.txt")
        if not os.path.exists(vocab_file):
            # 如果没有vocab文件，创建一个简单的
            print("创建简单的词汇表...")
            basic_vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]
            for i in range(97, 123):  # a-z
                basic_vocab.append(chr(i))
            for word in ["the", "a", "an", "and", "or", "but", "if", "then", "else", "this", "that"]:
                basic_vocab.append(word)
            with open(vocab_file, "w") as f:
                for word in basic_vocab:
                    f.write(word + "\n")
        
        tokenizer = BertTokenizer(vocab_file=vocab_file, do_lower_case=True)
    
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    # 数据处理
    tokenized_datasets = process_sst2_data(tokenizer, max_length=128, test_mode=test_mode)
    
    # 创建模型
    model = MambaForSequenceClassification(mamba_config).to(device)
    
    # 打印模型参数量
    param_count = sum(p.numel() for p in model.parameters())
    print(f"Mamba model parameters: {param_count:,}")
    
    # 使用与Transformer相似的学习率
    mamba_lr = 1e-4
    
    # 设置训练参数 - 匹配Transformer的训练条件
    training_args = TrainingArguments(
        output_dir=f"./results/mamba/{'test' if test_mode else 'full'}_scratch",
        learning_rate=mamba_lr,
        per_device_train_batch_size=16 if test_mode else 32,
        per_device_eval_batch_size=16 if test_mode else 64,
        num_train_epochs=1 if test_mode else 6,  # 匹配Transformer的训练轮数
        weight_decay=0.01,
        save_strategy="steps",  # 确保保存检查点
        save_steps=250,         # 每250步保存一次
        save_total_limit=2,     # 只保留最近的2个检查点
        seed=202504,
        eval_strategy="steps",  # 定期评估
        eval_steps=250,         # 每250步评估一次
        logging_steps=250,      # 每250步记录一次
        # 添加warm-up步骤
        warmup_ratio=0.1,
        # 使用线性学习率调度
        lr_scheduler_type="linear",
        # 添加梯度累积步骤
        gradient_accumulation_steps=2,
        # 添加梯度裁剪
        max_grad_norm=1.0,
        # 使用AdamW优化器
        optim="adamw_torch",
        # 加载最佳模型进行评估
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        greater_is_better=True,
    )
    
    # 创建优化回调以监控Mamba的训练过程
    class MambaTrainingCallback(TrainerCallback):
        def __init__(self):
            self.best_accuracy = 0.0
            self.best_step = 0
            self.loss_history = []
            self.accuracy_history = []
            
        def on_evaluate(self, args, state, control, metrics=None, **kwargs):
            if metrics and "eval_accuracy" in metrics:
                current_accuracy = metrics["eval_accuracy"]
                self.accuracy_history.append((state.global_step, current_accuracy))
                
                print(f"Step {state.global_step}: Accuracy = {current_accuracy:.4f}")
                
                if current_accuracy > self.best_accuracy:
                    self.best_accuracy = current_accuracy
                    self.best_step = state.global_step
                    print(f"New best accuracy: {self.best_accuracy:.4f} at step {self.best_step}")
                    
                    # 手动保存最佳模型
                    model = kwargs.get('model')
                    if model is not None:
                        best_model_dir = os.path.join(args.output_dir, "best_model")
                        os.makedirs(best_model_dir, exist_ok=True)
                        model.save_pretrained(best_model_dir)
                        print(f"已手动保存最佳模型到 {best_model_dir}")
                    
        def on_log(self, args, state, control, logs=None, **kwargs):
            if logs and "loss" in logs:
                self.loss_history.append((state.global_step, logs["loss"]))
                
                # 每500步打印训练曲线
                if state.global_step % 500 == 0 and len(self.loss_history) > 1:
                    recent_losses = [l[1] for l in self.loss_history[-10:]]
                    avg_loss = sum(recent_losses) / len(recent_losses)
                    print(f"Step {state.global_step}: Recent average loss = {avg_loss:.4f}")
                    
        def on_train_end(self, args, state, control, **kwargs):
            print(f"Training completed. Best accuracy: {self.best_accuracy:.4f} at step {self.best_step}")
            
            # 保存训练曲线
            if len(self.loss_history) > 0 and len(self.accuracy_history) > 0:
                steps, losses = zip(*self.loss_history)
                acc_steps, accuracies = zip(*self.accuracy_history)
                
                plt.figure(figsize=(12, 5))
                
                plt.subplot(1, 2, 1)
                plt.plot(steps, losses)
                plt.title('Training Loss')
                plt.xlabel('Step')
                plt.ylabel('Loss')
                
                plt.subplot(1, 2, 2)
                plt.plot(acc_steps, accuracies)
                plt.title('Validation Accuracy')
                plt.xlabel('Step')
                plt.ylabel('Accuracy')
                
                plt.tight_layout()
                os.makedirs('results/mamba/figures', exist_ok=True)
                plt.savefig('results/mamba/figures/training_curves.png')
                plt.close()
                
                # 保存训练历史到CSV
                import pandas as pd
                history_df = pd.DataFrame({
                    'step': steps,
                    'loss': losses
                })
                history_df.to_csv('results/mamba/figures/loss_history.csv', index=False)
                
                accuracy_df = pd.DataFrame({
                    'step': acc_steps,
                    'accuracy': accuracies
                })
                accuracy_df.to_csv('results/mamba/figures/accuracy_history.csv', index=False)
    
    # 创建训练器
    mamba_callback = MambaTrainingCallback()
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["validation"],
        compute_metrics=lambda eval_pred: compute_metrics_with_efficiency(
            eval_pred, "mamba", start_time=None
        ),
        callbacks=[mamba_callback],
        processing_class=tokenizer,  # 传入tokenizer以便保存
    )
    
    # 记录开始时间
    start_time = time.time()
    
    # 训练
    trainer.train()
    
    # 加载最佳模型进行最终评估
    best_model_dir = os.path.join(training_args.output_dir, "best_model")
    if os.path.exists(best_model_dir):
        print(f"加载最佳模型进行最终评估...")
        model = MambaForSequenceClassification.from_pretrained(best_model_dir).to(device)
        trainer.model = model
    else:
        print("未找到手动保存的最佳模型，使用Trainer保存的最佳模型或当前模型")
    
    # 评估
    eval_results = trainer.evaluate()
    
    # 计算最终指标
    eval_pred = trainer.predict(tokenized_datasets["validation"])
    metrics = compute_metrics_with_efficiency(eval_pred, "mamba")
    
    final_metrics = {
        "accuracy": metrics["accuracy"],
        "training_time": time.time() - start_time,
        "memory_usage": psutil.Process().memory_info().rss / 1024 ** 2
    }
    
    # 获取预测和标签进行详细评估
    from sklearn.metrics import precision_recall_fscore_support, roc_curve, auc, confusion_matrix
    
    predictions = eval_pred.predictions
    labels = eval_pred.label_ids
    
    # 获取预测类别和概率
    preds = np.argmax(predictions, axis=1)
    probabilities = torch.nn.functional.softmax(torch.tensor(predictions), dim=1).numpy()
    
    # 计算详细指标
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
    confusion_mat = confusion_matrix(labels, preds)
    fpr, tpr, _ = roc_curve(labels, probabilities[:, 1])
    roc_auc = auc(fpr, tpr)
    
    # 添加到final_metrics
    final_metrics["precision"] = precision
    final_metrics["recall"] = recall
    final_metrics["f1"] = f1
    final_metrics["roc_auc"] = roc_auc
    final_metrics["confusion_matrix"] = confusion_mat
    final_metrics["fpr"] = fpr
    final_metrics["tpr"] = tpr
    
    # 绘制并保存ROC曲线
    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Mamba ROC Curve')
    plt.legend(loc="lower right")
    os.makedirs('results/mamba/metrics', exist_ok=True)
    plt.savefig('results/mamba/metrics/roc_curve.png')
    plt.close()
    
    # 绘制并保存混淆矩阵
    plt.figure(figsize=(8, 6))
    plt.imshow(confusion_mat, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title('Mamba Confusion Matrix')
    plt.colorbar()
    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    
    # 添加数值标签
    for i in range(confusion_mat.shape[0]):
        for j in range(confusion_mat.shape[1]):
            plt.text(j, i, str(confusion_mat[i, j]),
                    ha="center", va="center",
                    color="white" if confusion_mat[i, j] > confusion_mat.max() / 2 else "black")
    
    plt.savefig('results/mamba/metrics/confusion_matrix.png')
    plt.close()
    
    # 测试模式下，简化长序列测试
    if test_mode:
        # 只测试一个短长度序列
        test_sentences = ["This is a test sentence."]
        print("Quick inference speed test...")
        inference_speed = test_inference_speed(
            model, tokenizer, [test_sentences[0]], device, max_length=32
        )
        final_metrics["inference_speed"] = inference_speed
        
        # 简化长序列测试，只测试一两个长度
        print("Quick long sequence test...")
        long_seq_results = test_long_sequence_performance(
            "mamba", tokenizer, MambaForSequenceClassification, mamba_config, device,
            lengths=[32, 64] if test_mode else [128, 256, 512, 1024, 2048, 4096, 8192, 16384]
        )
    else:
        # 完整测试
        test_sentences = [
            "This movie was great!", 
            "I hated this film.", 
            "A wonderful experience that I would recommend to everyone.",
            "The acting was mediocre at best, and the plot was predictable.",
            "I've never been so bored watching a film in my entire life."
        ]
        
        print("Testing inference speed...")
        inference_speed = test_inference_speed(
            model, tokenizer, test_sentences, device
        )
        final_metrics["inference_speed"] = inference_speed
        
        # 长序列测试
        print("Testing long sequence performance...")
        long_seq_results = test_long_sequence_performance(
            "mamba", tokenizer, MambaForSequenceClassification, mamba_config, device
        )
    
    final_metrics["long_sequence_results"] = long_seq_results
    
    # 保存详细评估指标到文件
    with open('results/mamba/metrics/detailed_metrics.txt', 'w') as f:
        f.write(f"MAMBA DETAILED METRICS\n")
        f.write("="*50 + "\n\n")
        f.write(f"Accuracy: {final_metrics['accuracy']:.4f}\n")
        f.write(f"Precision: {precision:.4f}\n")
        f.write(f"Recall: {recall:.4f}\n")
        f.write(f"F1 Score: {f1:.4f}\n")
        f.write(f"ROC AUC: {roc_auc:.4f}\n\n")
        
        f.write("Training Information:\n")
        f.write(f"Training time: {final_metrics['training_time']:.2f} seconds\n")
        f.write(f"Inference speed: {final_metrics['inference_speed']:.5f} seconds/sample\n")
        f.write(f"Memory usage: {final_metrics['memory_usage']:.2f} MB\n\n")
        
        f.write("Confusion Matrix:\n")
        f.write(str(confusion_mat) + "\n\n")
        
        f.write("Long Sequence Performance:\n")
        for length, results in long_seq_results.items():
            f.write(f"Length {length}: Time = {results['time']:.5f}s, Memory = {results['memory']:.2f} MB\n")
    
    # 保存Mamba结果摘要
    with open("results/mamba/summary_scratch.txt", "w") as f:
        f.write(f"MAMBA RESULTS (FROM SCRATCH) {'(TEST MODE)' if test_mode else ''}\\n")
        f.write("="*50 + "\\n")
        f.write(f"Accuracy: {final_metrics['accuracy']:.4f}\\n")
        f.write(f"Precision: {precision:.4f}\\n")
        f.write(f"Recall: {recall:.4f}\\n")
        f.write(f"F1 Score: {f1:.4f}\\n")
        f.write(f"ROC AUC: {roc_auc:.4f}\\n\\n")
        f.write(f"Training time: {final_metrics['training_time']:.2f} seconds\\n")
        f.write(f"Inference speed: {final_metrics['inference_speed']:.5f} seconds/sample\\n")
        f.write(f"Memory usage: {final_metrics['memory_usage']:.2f} MB\\n\\n")
        
        f.write("LONG SEQUENCE PERFORMANCE\\n")
        f.write("="*50 + "\\n")
        for length, results in long_seq_results.items():
            f.write(f"Length {length}: Time = {results['time']:.5f}s, Memory = {results['memory']:.2f} MB\\n")
    
    return final_metrics

In [5]:
# 改进的Transformer训练函数
def train_evaluate_transformer(test_mode=False):
    print(f"\n{'='*50}")
    print(f"Running Transformer experiment {'(TEST MODE)' if test_mode else ''} (FROM SCRATCH)")
    print(f"{'='*50}")
    
    # 使用本地tokenizer
    try:
        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", local_files_only=True)
        print("使用本地缓存的tokenizer")
    except:
        print("创建基本tokenizer...")
        from transformers import BertTokenizer
        vocab_file = os.path.join("vocab.txt")
        if not os.path.exists(vocab_file):
            # 如果没有vocab文件，创建一个简单的
            print("创建简单的词汇表...")
            basic_vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]
            for i in range(97, 123):  # a-z
                basic_vocab.append(chr(i))
            for word in ["the", "a", "an", "and", "or", "but", "if", "then", "else", "this", "that"]:
                basic_vocab.append(word)
            with open(vocab_file, "w") as f:
                for word in basic_vocab:
                    f.write(word + "\n")
        
        tokenizer = BertTokenizer(vocab_file=vocab_file, do_lower_case=True)
    
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    # 数据处理
    tokenized_datasets = process_sst2_data(tokenizer, max_length=128, test_mode=test_mode)
    
    # 创建模型 - 使用优化配置
    if test_mode:
        # 测试模式下使用较小的模型配置
        transformer_config = BertConfig(
            vocab_size=30522,
            hidden_size=128,
            num_hidden_layers=2,
            num_attention_heads=2,
            intermediate_size=512,
            hidden_act="gelu",
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
            max_position_embeddings=512,
            type_vocab_size=2,
            initializer_range=0.08,
            layer_norm_eps=1e-12,
            pad_token_id=0,
            num_labels=2,
        )
    else:
        # 从头创建优化的BERT模型
        print("从头创建优化的BERT模型...")
        transformer_config = BertConfig(
            vocab_size=30522,
            hidden_size=768,
            num_hidden_layers=6,  # 6层
            num_attention_heads=12,
            intermediate_size=3072,
            hidden_act="gelu",
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
            max_position_embeddings=512,
            type_vocab_size=2,
            initializer_range=0.08,  # 增加初始化范围
            layer_norm_eps=1e-12,
            pad_token_id=0,
            num_labels=2,
        )
    
    model = BertForSequenceClassification(transformer_config).to(device)
    
    # 打印模型参数量
    param_count = sum(p.numel() for p in model.parameters())
    print(f"Transformer model parameters: {param_count:,}")
    
    # 设置训练参数
    transformer_lr = 1e-4  # 调整学习率
    
    # 设置训练参数 - 调整以优化从头训练性能
    training_args = TrainingArguments(
        output_dir=f"./results/transformer/{'test' if test_mode else 'full'}_scratch",
        learning_rate=transformer_lr,
        per_device_train_batch_size=16 if test_mode else 32,
        per_device_eval_batch_size=16 if test_mode else 64,
        num_train_epochs=1 if test_mode else 6,
        weight_decay=0.01,
        save_strategy="steps",  # 确保保存检查点
        save_steps=250,         # 每250步保存一次
        save_total_limit=2,     # 只保留最近的2个检查点
        seed=202504,
        eval_strategy="steps",  # 定期评估
        eval_steps=250,         # 每250步评估一次
        logging_steps=250,      # 每250步记录一次
        # 使用比例预热，更加灵活
        warmup_ratio=0.1,
        # 使用线性学习率调度
        lr_scheduler_type="linear",
        # 增加梯度累积步骤，帮助稳定训练
        gradient_accumulation_steps=2,
        # 添加梯度裁剪，避免梯度爆炸
        max_grad_norm=1.0,
        # 使用AdamW优化器
        optim="adamw_torch",
        # 加载最佳模型进行评估
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        greater_is_better=True,
    )
    
    # 创建一个自定义回调来监控训练过程
    class TransformerTrainingCallback(TrainerCallback):
        def __init__(self):
            self.best_accuracy = 0.0
            self.best_step = 0
            self.loss_history = []
            self.accuracy_history = []
            
        def on_evaluate(self, args, state, control, metrics=None, **kwargs):
            if metrics and "eval_accuracy" in metrics:
                current_accuracy = metrics["eval_accuracy"]
                self.accuracy_history.append((state.global_step, current_accuracy))
                
                print(f"Step {state.global_step}: Accuracy = {current_accuracy:.4f}")
                
                if current_accuracy > self.best_accuracy:
                    self.best_accuracy = current_accuracy
                    self.best_step = state.global_step
                    print(f"New best accuracy: {self.best_accuracy:.4f} at step {self.best_step}")
                    
                    # 手动保存最佳模型
                    model = kwargs.get('model')
                    if model is not None:
                        best_model_dir = os.path.join(args.output_dir, "best_model")
                        os.makedirs(best_model_dir, exist_ok=True)
                        model.save_pretrained(best_model_dir)
                        print(f"已手动保存最佳模型到 {best_model_dir}")
                    
        def on_log(self, args, state, control, logs=None, **kwargs):
            if logs and "loss" in logs:
                self.loss_history.append((state.global_step, logs["loss"]))
                
                # 每500步打印训练曲线
                if state.global_step % 500 == 0 and len(self.loss_history) > 1:
                    recent_losses = [l[1] for l in self.loss_history[-10:]]
                    avg_loss = sum(recent_losses) / len(recent_losses)
                    print(f"Step {state.global_step}: Recent average loss = {avg_loss:.4f}")
                    
        def on_train_end(self, args, state, control, **kwargs):
            print(f"Training completed. Best accuracy: {self.best_accuracy:.4f} at step {self.best_step}")
            
            # 保存训练曲线
            if len(self.loss_history) > 0 and len(self.accuracy_history) > 0:
                steps, losses = zip(*self.loss_history)
                acc_steps, accuracies = zip(*self.accuracy_history)
                
                plt.figure(figsize=(12, 5))
                
                plt.subplot(1, 2, 1)
                plt.plot(steps, losses)
                plt.title('Training Loss')
                plt.xlabel('Step')
                plt.ylabel('Loss')
                
                plt.subplot(1, 2, 2)
                plt.plot(acc_steps, accuracies)
                plt.title('Validation Accuracy')
                plt.xlabel('Step')
                plt.ylabel('Accuracy')
                
                plt.tight_layout()
                os.makedirs('results/transformer/figures', exist_ok=True)
                plt.savefig('results/transformer/figures/training_curves.png')
                plt.close()
                
                # 保存训练历史到CSV
                import pandas as pd
                history_df = pd.DataFrame({
                    'step': steps,
                    'loss': losses
                })
                history_df.to_csv('results/transformer/figures/loss_history.csv', index=False)
                
                accuracy_df = pd.DataFrame({
                    'step': acc_steps,
                    'accuracy': accuracies
                })
                accuracy_df.to_csv('results/transformer/figures/accuracy_history.csv', index=False)
    
    # 创建训练器
    transformer_callback = TransformerTrainingCallback()
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["validation"],
        compute_metrics=lambda eval_pred: compute_metrics_with_efficiency(
            eval_pred, "transformer", start_time=None
        ),
        callbacks=[transformer_callback],
        processing_class=tokenizer,  # 传入tokenizer以便保存
    )
    
    # 记录开始时间
    start_time = time.time()
    
    # 训练
    trainer.train()
    
    # 加载最佳模型进行最终评估
    best_model_dir = os.path.join(training_args.output_dir, "best_model")
    if os.path.exists(best_model_dir):
        print(f"加载最佳模型进行最终评估...")
        model = BertForSequenceClassification.from_pretrained(best_model_dir).to(device)
        trainer.model = model
    else:
        print("未找到手动保存的最佳模型，使用Trainer保存的最佳模型或当前模型")
    
    # 评估
    eval_results = trainer.evaluate()
    
    # 计算最终指标
    eval_pred = trainer.predict(tokenized_datasets["validation"])
    metrics = compute_metrics_with_efficiency(eval_pred, "transformer")
    
    final_metrics = {
        "accuracy": metrics["accuracy"],
        "training_time": time.time() - start_time,
        "memory_usage": psutil.Process().memory_info().rss / 1024 ** 2
    }
    
    # 获取预测和标签进行详细评估
    from sklearn.metrics import precision_recall_fscore_support, roc_curve, auc, confusion_matrix
    
    predictions = eval_pred.predictions
    labels = eval_pred.label_ids
    
    # 获取预测类别和概率
    preds = np.argmax(predictions, axis=1)
    probabilities = torch.nn.functional.softmax(torch.tensor(predictions), dim=1).numpy()
    
    # 计算详细指标
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
    confusion_mat = confusion_matrix(labels, preds)
    fpr, tpr, _ = roc_curve(labels, probabilities[:, 1])
    roc_auc = auc(fpr, tpr)
    
    # 添加到final_metrics
    final_metrics["precision"] = precision
    final_metrics["recall"] = recall
    final_metrics["f1"] = f1
    final_metrics["roc_auc"] = roc_auc
    final_metrics["confusion_matrix"] = confusion_mat
    final_metrics["fpr"] = fpr
    final_metrics["tpr"] = tpr
    
    # 绘制并保存ROC曲线
    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Transformer ROC Curve')
    plt.legend(loc="lower right")
    os.makedirs('results/transformer/metrics', exist_ok=True)
    plt.savefig('results/transformer/metrics/roc_curve.png')
    plt.close()
    
    # 绘制并保存混淆矩阵
    plt.figure(figsize=(8, 6))
    plt.imshow(confusion_mat, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title('Transformer Confusion Matrix')
    plt.colorbar()
    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    
    # 添加数值标签
    for i in range(confusion_mat.shape[0]):
        for j in range(confusion_mat.shape[1]):
            plt.text(j, i, str(confusion_mat[i, j]),
                    ha="center", va="center",
                    color="white" if confusion_mat[i, j] > confusion_mat.max() / 2 else "black")
    
    plt.savefig('results/transformer/metrics/confusion_matrix.png')
    plt.close()
    
    # 测试模式下，简化长序列测试
    if test_mode:
        # 只测试一个短长度序列
        test_sentences = ["This is a test sentence."]
        print("Quick inference speed test...")
        inference_speed = test_inference_speed(
            model, tokenizer, [test_sentences[0]], device, max_length=32
        )
        final_metrics["inference_speed"] = inference_speed
        
        # 简化长序列测试，只测试一两个长度
        print("Quick long sequence test...")
        long_seq_results = test_long_sequence_performance(
            "transformer", tokenizer, BertForSequenceClassification, transformer_config, device,
            lengths=[32, 64] if test_mode else [128, 256, 512, 1024, 2048, 4096, 8192, 16384]
        )
    else:
        # 完整测试
        test_sentences = [
            "This movie was great!", 
            "I hated this film.", 
            "A wonderful experience that I would recommend to everyone.",
            "The acting was mediocre at best, and the plot was predictable.",
            "I've never been so bored watching a film in my entire life."
        ]
        
        print("Testing inference speed...")
        inference_speed = test_inference_speed(
            model, tokenizer, test_sentences, device
        )
        final_metrics["inference_speed"] = inference_speed
        
        # 长序列测试
        print("Testing long sequence performance...")
        long_seq_results = test_long_sequence_performance(
            "transformer", tokenizer, BertForSequenceClassification, transformer_config, device
        )
    
    final_metrics["long_sequence_results"] = long_seq_results
    
    # 保存详细评估指标到文件
    with open('results/transformer/metrics/detailed_metrics.txt', 'w') as f:
        f.write(f"TRANSFORMER DETAILED METRICS\n")
        f.write("="*50 + "\n\n")
        f.write(f"Accuracy: {final_metrics['accuracy']:.4f}\n")
        f.write(f"Precision: {precision:.4f}\n")
        f.write(f"Recall: {recall:.4f}\n")
        f.write(f"F1 Score: {f1:.4f}\n")
        f.write(f"ROC AUC: {roc_auc:.4f}\n\n")
        
        f.write("Training Information:\n")
        f.write(f"Training time: {final_metrics['training_time']:.2f} seconds\n")
        f.write(f"Inference speed: {final_metrics['inference_speed']:.5f} seconds/sample\n")
        f.write(f"Memory usage: {final_metrics['memory_usage']:.2f} MB\n\n")
        
        f.write("Confusion Matrix:\n")
        f.write(str(confusion_mat) + "\n\n")
        
        f.write("Long Sequence Performance:\n")
        for length, results in long_seq_results.items():
            f.write(f"Length {length}: Time = {results['time']:.5f}s, Memory = {results['memory']:.2f} MB\n")
    
    # 保存Transformer结果摘要
    with open("results/transformer/summary_scratch.txt", "w") as f:
        f.write(f"TRANSFORMER RESULTS (FROM SCRATCH) {'(TEST MODE)' if test_mode else ''}\\n")
        f.write("="*50 + "\\n")
        f.write(f"Accuracy: {final_metrics['accuracy']:.4f}\\n")
        f.write(f"Precision: {precision:.4f}\\n")
        f.write(f"Recall: {recall:.4f}\\n")
        f.write(f"F1 Score: {f1:.4f}\\n")
        f.write(f"ROC AUC: {roc_auc:.4f}\\n\\n")
        f.write(f"Training time: {final_metrics['training_time']:.2f} seconds\\n")
        f.write(f"Inference speed: {final_metrics['inference_speed']:.5f} seconds/sample\\n")
        f.write(f"Memory usage: {final_metrics['memory_usage']:.2f} MB\\n\\n")
        
        f.write("LONG SEQUENCE PERFORMANCE\\n")
        f.write("="*50 + "\\n")
        for length, results in long_seq_results.items():
            f.write(f"Length {length}: Time = {results['time']:.5f}s, Memory = {results['memory']:.2f} MB\\n")
    
    return final_metrics

In [6]:
# 结果可视化函数
def visualize_results(mamba_results, transformer_results):
    # 创建结果目录
    os.makedirs("results/figures", exist_ok=True)
    
    # 准确率和速度对比
    metrics = ["accuracy", "training_time", "inference_speed", "memory_usage"]
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    axes = axes.flatten()
    
    for i, metric in enumerate(metrics):
        data = {
            "Mamba": mamba_results[metric],
            "Transformer": transformer_results[metric]
        }
        ax = axes[i]
        bars = ax.bar(data.keys(), data.values())
        ax.set_title(f"{metric.replace('_', ' ').title()}")
        
        # 添加数值标签
        for bar in bars:
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2., height,
                    f"{height:.4f}",
                    ha='center', va='bottom')
    
    plt.tight_layout()
    plt.savefig("results/figures/basic_metrics_comparison.png")
    plt.close()
    
    # 长序列性能对比
    lengths = list(mamba_results["long_sequence_results"].keys())
    
    # 时间对比
    mamba_times = [mamba_results["long_sequence_results"][l]["time"] for l in lengths]
    transformer_times = [transformer_results["long_sequence_results"][l]["time"] for l in lengths]
    
    plt.figure(figsize=(10, 6))
    plt.plot(lengths, mamba_times, 'o-', label='Mamba')
    plt.plot(lengths, transformer_times, 's-', label='Transformer')
    plt.xlabel('Sequence Length')
    plt.ylabel('Inference Time (s)')
    plt.title('Inference Time vs Sequence Length')
    plt.legend()
    plt.grid(True)
    plt.savefig("results/figures/sequence_length_time.png")
    plt.close()
    
    # 内存对比
    mamba_memory = [mamba_results["long_sequence_results"][l]["memory"] for l in lengths]
    transformer_memory = [transformer_results["long_sequence_results"][l]["memory"] for l in lengths]
    
    plt.figure(figsize=(10, 6))
    plt.plot(lengths, mamba_memory, 'o-', label='Mamba')
    plt.plot(lengths, transformer_memory, 's-', label='Transformer')
    plt.xlabel('Sequence Length')
    plt.ylabel('Memory Usage (MB)')
    plt.title('Memory Usage vs Sequence Length')
    plt.legend()
    plt.grid(True)
    plt.savefig("results/figures/sequence_length_memory.png")
    plt.close()
    
    # 创建性能比率图 - 这将显示Mamba相对于Transformer的优势
    plt.figure(figsize=(10, 6))
    time_ratios = [transformer_times[i]/mamba_times[i] for i in range(len(lengths))]
    memory_ratios = [transformer_memory[i]/mamba_memory[i] for i in range(len(lengths))]
    
    plt.plot(lengths, time_ratios, 'o-', label='Time Ratio (Transformer/Mamba)')
    plt.plot(lengths, memory_ratios, 's-', label='Memory Ratio (Transformer/Mamba)')
    plt.axhline(y=1, color='r', linestyle='--')
    plt.xlabel('Sequence Length')
    plt.ylabel('Ratio (Transformer/Mamba)')
    plt.title('Performance Ratio vs Sequence Length\n(Higher means Mamba is more efficient)')
    plt.legend()
    plt.grid(True)
    plt.savefig("results/figures/performance_ratio.png")
    plt.close()
    
    # 保存数值结果到CSV
    results_df = pd.DataFrame({
        'Metric': metrics,
        'Mamba': [mamba_results[m] for m in metrics],
        'Transformer': [transformer_results[m] for m in metrics],
        'Ratio (Transformer/Mamba)': [transformer_results[m]/mamba_results[m] if mamba_results[m] > 0 else 0 for m in metrics]
    })
    results_df.to_csv("results/basic_metrics.csv", index=False)
    
    # 长序列结果
    long_seq_df = pd.DataFrame({
        'Length': lengths,
        'Mamba_Time': mamba_times,
        'Transformer_Time': transformer_times,
        'Time_Ratio': time_ratios,
        'Mamba_Memory': mamba_memory,
        'Transformer_Memory': transformer_memory,
        'Memory_Ratio': memory_ratios
    })
    long_seq_df.to_csv("results/long_sequence_metrics.csv", index=False)

In [7]:
# 详细的模型比较函数
def compare_detailed_results(mamba_results, transformer_results):
    """
    比较两个模型的详细结果，生成综合报告和可视化
    
    参数:
        mamba_results: Mamba模型的结果字典
        transformer_results: Transformer模型的结果字典
    """
    import os
    import numpy as np
    import matplotlib.pyplot as plt
    import pandas as pd
    from sklearn.metrics import roc_curve, auc
    
    # 创建输出目录
    os.makedirs("results/comparison", exist_ok=True)
    
    print("\n" + "="*50)
    print(f"DETAILED MODEL COMPARISON")
    print("="*50)
    
    # 1. 基本性能指标
    metrics = ["accuracy", "precision", "recall", "f1", "roc_auc"]
    metric_names = ["Accuracy", "Precision", "Recall", "F1 Score", "ROC AUC"]
    
    print("\nCLASSIFICATION PERFORMANCE")
    print("-"*30)
    
    # 创建表格数据
    comparison_data = {
        "Metric": metric_names,
        "Mamba": [mamba_results.get(m, 0) for m in metrics],
        "Transformer": [transformer_results.get(m, 0) for m in metrics],
    }
    
    # 计算差异和比率
    comparison_data["Difference"] = [
        comparison_data["Mamba"][i] - comparison_data["Transformer"][i] 
        for i in range(len(metrics))
    ]
    
    comparison_data["Ratio (M/T)"] = [
        comparison_data["Mamba"][i] / comparison_data["Transformer"][i] if comparison_data["Transformer"][i] > 0 else 0
        for i in range(len(metrics))
    ]
    
    # 打印比较结果
    for i, metric in enumerate(metric_names):
        print(f"{metric}:")
        print(f"  Mamba: {comparison_data['Mamba'][i]:.4f}")
        print(f"  Transformer: {comparison_data['Transformer'][i]:.4f}")
        print(f"  Difference: {comparison_data['Difference'][i]:.4f}")
        print(f"  Ratio (M/T): {comparison_data['Ratio (M/T)'][i]:.4f}")
        print()
    
    # 2. 资源和效率指标
    efficiency_metrics = ["training_time", "inference_speed", "memory_usage"]
    efficiency_names = ["Training Time (s)", "Inference Speed (s/sample)", "Memory Usage (MB)"]
    
    print("\nEFFICIENCY METRICS")
    print("-"*30)
    
    efficiency_data = {
        "Metric": efficiency_names,
        "Mamba": [mamba_results.get(m, 0) for m in efficiency_metrics],
        "Transformer": [transformer_results.get(m, 0) for m in efficiency_metrics],
    }
    
    # 计算比率 - 注意对于效率指标，较小值更好，所以计算T/M
    efficiency_data["Ratio (T/M)"] = [
        efficiency_data["Transformer"][i] / efficiency_data["Mamba"][i] if efficiency_data["Mamba"][i] > 0 else 0
        for i in range(len(efficiency_metrics))
    ]
    
    # 打印效率比较结果
    for i, metric in enumerate(efficiency_names):
        print(f"{metric}:")
        print(f"  Mamba: {efficiency_data['Mamba'][i]:.4f}")
        print(f"  Transformer: {efficiency_data['Transformer'][i]:.4f}")
        print(f"  Ratio (T/M): {efficiency_data['Ratio (T/M)'][i]:.4f}")
        print(f"  {'Mamba' if efficiency_data['Ratio (T/M)'][i] > 1 else 'Transformer'} is more efficient")
        print()
    
    # 3. 长序列性能比较
    if "long_sequence_results" in mamba_results and "long_sequence_results" in transformer_results:
        print("\nLONG SEQUENCE PERFORMANCE")
        print("-"*30)
        
        # 获取长度数据
        lengths = list(mamba_results["long_sequence_results"].keys())
        
        # 收集时间和内存数据
        mamba_times = []
        transformer_times = []
        mamba_memory = []
        transformer_memory = []
        time_ratios = []
        memory_ratios = []
        
        for length in lengths:
            mamba_time = mamba_results["long_sequence_results"][length]["time"]
            transformer_time = transformer_results["long_sequence_results"][length]["time"]
            mamba_mem = mamba_results["long_sequence_results"][length]["memory"]
            transformer_mem = transformer_results["long_sequence_results"][length]["memory"]
            
            time_ratio = transformer_time / mamba_time if mamba_time > 0 else 0
            memory_ratio = transformer_mem / mamba_mem if mamba_mem > 0 else 0
            
            mamba_times.append(mamba_time)
            transformer_times.append(transformer_time)
            mamba_memory.append(mamba_mem)
            transformer_memory.append(transformer_mem)
            time_ratios.append(time_ratio)
            memory_ratios.append(memory_ratio)
            
            print(f"Sequence Length {length}:")
            print(f"  Time (s): Mamba = {mamba_time:.5f}, Transformer = {transformer_time:.5f}, Ratio (T/M) = {time_ratio:.2f}x")
            print(f"  Memory (MB): Mamba = {mamba_mem:.2f}, Transformer = {transformer_mem:.2f}, Ratio (T/M) = {memory_ratio:.2f}x")
            print(f"  {'Mamba' if time_ratio > 1 else 'Transformer'} is faster")
            print(f"  {'Mamba' if memory_ratio > 1 else 'Transformer'} is more memory efficient")
            print()
    
    # 创建可视化
    
    # 1. 分类性能条形图
    plt.figure(figsize=(12, 6))
    x = np.arange(len(metric_names))
    width = 0.35
    
    plt.bar(x - width/2, comparison_data["Mamba"], width, label='Mamba', color='blue', alpha=0.7)
    plt.bar(x + width/2, comparison_data["Transformer"], width, label='Transformer', color='red', alpha=0.7)
    
    plt.xlabel('Metrics')
    plt.ylabel('Score')
    plt.title('Classification Performance Comparison')
    plt.xticks(x, metric_names)
    plt.legend()
    
    # 添加数值标签
    for i, v in enumerate(comparison_data["Mamba"]):
        plt.text(i - width/2, v + 0.02, f"{v:.3f}", ha='center')
    
    for i, v in enumerate(comparison_data["Transformer"]):
        plt.text(i + width/2, v + 0.02, f"{v:.3f}", ha='center')
    
    plt.ylim(0, max(max(comparison_data["Mamba"]), max(comparison_data["Transformer"])) * 1.2)
    plt.tight_layout()
    plt.savefig("results/comparison/classification_metrics.png")
    plt.close()
    
    # 2. 效率指标条形图
    plt.figure(figsize=(12, 6))
    x = np.arange(len(efficiency_names))
    
    # 归一化效率指标以便在同一图表上显示
    max_values = [max(efficiency_data["Mamba"][i], efficiency_data["Transformer"][i]) for i in range(len(efficiency_metrics))]
    norm_mamba = [efficiency_data["Mamba"][i] / max_values[i] for i in range(len(efficiency_metrics))]
    norm_transformer = [efficiency_data["Transformer"][i] / max_values[i] for i in range(len(efficiency_metrics))]
    
    plt.bar(x - width/2, norm_mamba, width, label='Mamba', color='blue', alpha=0.7)
    plt.bar(x + width/2, norm_transformer, width, label='Transformer', color='red', alpha=0.7)
    
    plt.xlabel('Metrics')
    plt.ylabel('Normalized Value (lower is better)')
    plt.title('Efficiency Metrics Comparison (Normalized)')
    plt.xticks(x, efficiency_names)
    plt.legend()
    
    # 添加原始数值标签
    for i, v in enumerate(efficiency_data["Mamba"]):
        plt.text(i - width/2, norm_mamba[i] + 0.05, f"{v:.2f}", ha='center')
    
    for i, v in enumerate(efficiency_data["Transformer"]):
        plt.text(i + width/2, norm_transformer[i] + 0.05, f"{v:.2f}", ha='center')
    
    plt.ylim(0, 1.3)
    plt.tight_layout()
    plt.savefig("results/comparison/efficiency_metrics.png")
    plt.close()
    
    # 3. 长序列性能图
    if "long_sequence_results" in mamba_results and "long_sequence_results" in transformer_results:
        # 时间对比
        plt.figure(figsize=(10, 6))
        plt.plot(lengths, mamba_times, 'o-', label='Mamba', color='blue')
        plt.plot(lengths, transformer_times, 's-', label='Transformer', color='red')
        plt.xlabel('Sequence Length')
        plt.ylabel('Inference Time (s)')
        plt.title('Inference Time vs Sequence Length')
        plt.legend()
        plt.grid(True)
        plt.savefig("results/comparison/sequence_length_time.png")
        plt.close()
        
        # 内存对比
        plt.figure(figsize=(10, 6))
        plt.plot(lengths, mamba_memory, 'o-', label='Mamba', color='blue')
        plt.plot(lengths, transformer_memory, 's-', label='Transformer', color='red')
        plt.xlabel('Sequence Length')
        plt.ylabel('Memory Usage (MB)')
        plt.title('Memory Usage vs Sequence Length')
        plt.legend()
        plt.grid(True)
        plt.savefig("results/comparison/sequence_length_memory.png")
        plt.close()
        
        # 性能比率
        plt.figure(figsize=(10, 6))
        plt.plot(lengths, time_ratios, 'o-', label='Time Ratio (T/M)', color='blue')
        plt.plot(lengths, memory_ratios, 's-', label='Memory Ratio (T/M)', color='red')
        plt.axhline(y=1, color='gray', linestyle='--')
        plt.xlabel('Sequence Length')
        plt.ylabel('Ratio (Transformer/Mamba)')
        plt.title('Performance Ratio vs Sequence Length\n(Higher means Mamba is more efficient)')
        plt.legend()
        plt.grid(True)
        plt.savefig("results/comparison/performance_ratio.png")
        plt.close()
    
    # 4. ROC曲线比较
    if all(k in mamba_results for k in ["fpr", "tpr", "roc_auc"]) and all(k in transformer_results for k in ["fpr", "tpr", "roc_auc"]):
        plt.figure(figsize=(10, 8))
        plt.plot(mamba_results["fpr"], mamba_results["tpr"], 
                 color='blue', lw=2, 
                 label=f'Mamba ROC (AUC = {mamba_results["roc_auc"]:.4f})')
        plt.plot(transformer_results["fpr"], transformer_results["tpr"], 
                 color='red', lw=2, 
                 label=f'Transformer ROC (AUC = {transformer_results["roc_auc"]:.4f})')
        plt.plot([0, 1], [0, 1], color='gray', lw=2, linestyle='--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('ROC Curves Comparison')
        plt.legend(loc="lower right")
        plt.savefig("results/comparison/roc_comparison.png")
        plt.close()
    
    # 创建混淆矩阵比较（如果可用）
    if "confusion_matrix" in mamba_results and "confusion_matrix" in transformer_results:
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
        
        # Mamba混淆矩阵
        im1 = ax1.imshow(mamba_results["confusion_matrix"], interpolation='nearest', cmap=plt.cm.Blues)
        ax1.set_title('Mamba Confusion Matrix')
        ax1.set_xlabel('Predicted Label')
        ax1.set_ylabel('True Label')
        
        # 添加文本
        for i in range(mamba_results["confusion_matrix"].shape[0]):
            for j in range(mamba_results["confusion_matrix"].shape[1]):
                ax1.text(j, i, format(mamba_results["confusion_matrix"][i, j], 'd'),
                       ha="center", va="center",
                       color="white" if mamba_results["confusion_matrix"][i, j] > mamba_results["confusion_matrix"].max() / 2 else "black")
        
        # Transformer混淆矩阵
        im2 = ax2.imshow(transformer_results["confusion_matrix"], interpolation='nearest', cmap=plt.cm.Reds)
        ax2.set_title('Transformer Confusion Matrix')
        ax2.set_xlabel('Predicted Label')
        ax2.set_ylabel('True Label')
        
        # 添加文本
        for i in range(transformer_results["confusion_matrix"].shape[0]):
            for j in range(transformer_results["confusion_matrix"].shape[1]):
                ax2.text(j, i, format(transformer_results["confusion_matrix"][i, j], 'd'),
                       ha="center", va="center",
                       color="white" if transformer_results["confusion_matrix"][i, j] > transformer_results["confusion_matrix"].max() / 2 else "black")
        
        plt.tight_layout()
        plt.savefig("results/comparison/confusion_matrices.png")
        plt.close()
    
    # 保存比较结果到CSV
    pd.DataFrame(comparison_data).to_csv("results/comparison/classification_metrics.csv", index=False)
    pd.DataFrame(efficiency_data).to_csv("results/comparison/efficiency_metrics.csv", index=False)
    
    # 如果有长序列数据，保存到CSV
    if "long_sequence_results" in mamba_results and "long_sequence_results" in transformer_results:
        long_seq_df = pd.DataFrame({
            'Length': lengths,
            'Mamba_Time': mamba_times,
            'Transformer_Time': transformer_times,
            'Time_Ratio': time_ratios,
            'Mamba_Memory': mamba_memory,
            'Transformer_Memory': transformer_memory,
            'Memory_Ratio': memory_ratios
        })
        long_seq_df.to_csv("results/comparison/long_sequence_metrics.csv", index=False)
    
    # 保存详细比较报告到文本文件
    with open("results/comparison/detailed_comparison.txt", "w") as f:
        f.write("DETAILED MODEL COMPARISON REPORT\n")
        f.write("="*50 + "\n\n")
        
        # 1. 分类性能
        f.write("CLASSIFICATION PERFORMANCE\n")
        f.write("-"*30 + "\n\n")
        
        for i, metric in enumerate(metric_names):
            f.write(f"{metric}:\n")
            f.write(f"  Mamba: {comparison_data['Mamba'][i]:.4f}\n")
            f.write(f"  Transformer: {comparison_data['Transformer'][i]:.4f}\n")
            f.write(f"  Difference: {comparison_data['Difference'][i]:.4f}\n")
            f.write(f"  Ratio (M/T): {comparison_data['Ratio (M/T)'][i]:.4f}\n\n")
        
        # 2. 效率指标
        f.write("\nEFFICIENCY METRICS\n")
        f.write("-"*30 + "\n\n")
        
        for i, metric in enumerate(efficiency_names):
            f.write(f"{metric}:\n")
            f.write(f"  Mamba: {efficiency_data['Mamba'][i]:.4f}\n")
            f.write(f"  Transformer: {efficiency_data['Transformer'][i]:.4f}\n")
            f.write(f"  Ratio (T/M): {efficiency_data['Ratio (T/M)'][i]:.4f}\n")
            f.write(f"  {'Mamba' if efficiency_data['Ratio (T/M)'][i] > 1 else 'Transformer'} is more efficient\n\n")
        
        # 3. 长序列性能
        if "long_sequence_results" in mamba_results and "long_sequence_results" in transformer_results:
            f.write("\nLONG SEQUENCE PERFORMANCE\n")
            f.write("-"*30 + "\n\n")
            
            for i, length in enumerate(lengths):
                f.write(f"Sequence Length {length}:\n")
                f.write(f"  Time (s): Mamba = {mamba_times[i]:.5f}, Transformer = {transformer_times[i]:.5f}, Ratio (T/M) = {time_ratios[i]:.2f}x\n")
                f.write(f"  Memory (MB): Mamba = {mamba_memory[i]:.2f}, Transformer = {transformer_memory[i]:.2f}, Ratio (T/M) = {memory_ratios[i]:.2f}x\n")
                f.write(f"  {'Mamba' if time_ratios[i] > 1 else 'Transformer'} is faster\n")
                f.write(f"  {'Mamba' if memory_ratios[i] > 1 else 'Transformer'} is more memory efficient\n\n")
        
        # 4. 总结
        f.write("\nSUMMARY\n")
        f.write("-"*30 + "\n\n")
        
        # 分类性能总结
        avg_accuracy_diff = comparison_data["Difference"][0]
        if avg_accuracy_diff > 0.02:
            f.write(f"Mamba outperforms Transformer in classification accuracy by {avg_accuracy_diff:.4f}\n")
        elif avg_accuracy_diff < -0.02:
            f.write(f"Transformer outperforms Mamba in classification accuracy by {-avg_accuracy_diff:.4f}\n")
        else:
            f.write(f"Mamba and Transformer have similar classification accuracy (difference: {avg_accuracy_diff:.4f})\n")
        
        # 效率总结
        avg_time_ratio = efficiency_data["Ratio (T/M)"][0]
        avg_memory_ratio = efficiency_data["Ratio (T/M)"][2]
        
        if avg_time_ratio > 1.1:
            f.write(f"Mamba is {avg_time_ratio:.2f}x faster in training than Transformer\n")
        elif avg_time_ratio < 0.9:
            f.write(f"Transformer is {1/avg_time_ratio:.2f}x faster in training than Mamba\n")
        else:
            f.write(f"Mamba and Transformer have similar training speed\n")
        
        if avg_memory_ratio > 1.1:
            f.write(f"Mamba uses {1/avg_memory_ratio:.2f}x less memory than Transformer\n")
        elif avg_memory_ratio < 0.9:
            f.write(f"Transformer uses {avg_memory_ratio:.2f}x less memory than Mamba\n")
        else:
            f.write(f"Mamba and Transformer have similar memory usage\n")
        
        # 长序列总结
        if "long_sequence_results" in mamba_results and "long_sequence_results" in transformer_results:
            avg_long_time_ratio = sum(time_ratios) / len(time_ratios)
            avg_long_memory_ratio = sum(memory_ratios) / len(memory_ratios)
            
            if avg_long_time_ratio > 1.1:
                f.write(f"Mamba is on average {avg_long_time_ratio:.2f}x faster for long sequences than Transformer\n")
            elif avg_long_time_ratio < 0.9:
                f.write(f"Transformer is on average {1/avg_long_time_ratio:.2f}x faster for long sequences than Mamba\n")
            else:
                f.write(f"Mamba and Transformer have similar speed for long sequences\n")
            
            if avg_long_memory_ratio > 1.1:
                f.write(f"Mamba uses on average {1/avg_long_memory_ratio:.2f}x less memory for long sequences than Transformer\n")
            elif avg_long_memory_ratio < 0.9:
                f.write(f"Transformer uses on average {avg_long_memory_ratio:.2f}x less memory for long sequences than Mamba\n")
            else:
                f.write(f"Mamba and Transformer have similar memory usage for long sequences\n")
        
        # 最终结论
        f.write("\nFINAL CONCLUSION:\n")
        if comparison_data["Difference"][0] > 0 and avg_time_ratio > 1:
            f.write("Mamba appears to be both more accurate and more efficient than Transformer for this task.\n")
        elif comparison_data["Difference"][0] < 0 and avg_time_ratio < 1:
            f.write("Transformer appears to be both more accurate and more efficient than Mamba for this task.\n")
        elif comparison_data["Difference"][0] > 0 and avg_time_ratio < 1:
            f.write("Mamba appears to be more accurate but less efficient than Transformer for this task.\n")
        elif comparison_data["Difference"][0] < 0 and avg_time_ratio > 1:
            f.write("Transformer appears to be more accurate but less efficient than Mamba for this task.\n")
        else:
            f.write("Both models show comparable performance with different trade-offs.\n")
    
    print("\nDetailed comparison completed. Results saved to 'results/comparison/' directory.")
    
    return {
        "classification_metrics": comparison_data,
        "efficiency_metrics": efficiency_data,
        "long_sequence_metrics": {"lengths": lengths, "time_ratios": time_ratios, "memory_ratios": memory_ratios} 
        if "long_sequence_results" in mamba_results and "long_sequence_results" in transformer_results else None
    }

In [8]:
# 运行完整实验的函数
def run_complete_experiment(test_mode=False):
    """
    运行完整的模型比较实验
    
    参数:
        test_mode: 是否运行测试模式(默认: False)
    
    返回:
        dict: 包含实验结果的字典
    """
    import time
    import os
    
    # 记录开始时间
    experiment_start_time = time.time()
    
    print(f"\n{'='*30} STARTING EXPERIMENT {'(TEST MODE)' if test_mode else ''} {'='*30}")
    
    # 创建结果目录
    os.makedirs("results", exist_ok=True)
    os.makedirs("results/mamba", exist_ok=True)
    os.makedirs("results/transformer", exist_ok=True)
    os.makedirs("results/comparison", exist_ok=True)
    
    # 训练Mamba模型
    print("\n[1/3] Training and evaluating Mamba model...")
    try:
        mamba_results = train_evaluate_mamba(test_mode)
        print("Mamba training and evaluation completed successfully!")
    except Exception as e:
        print(f"Error in Mamba experiment: {e}")
        import traceback
        traceback.print_exc()
        mamba_results = None
    
    # 训练Transformer模型
    print("\n[2/3] Training and evaluating Transformer model...")
    try:
        transformer_results = train_evaluate_transformer(test_mode)
        print("Transformer training and evaluation completed successfully!")
    except Exception as e:
        print(f"Error in Transformer experiment: {e}")
        import traceback
        traceback.print_exc()
        transformer_results = None
    
    # 比较结果
    if mamba_results is not None and transformer_results is not None:
        print("\n[3/3] Comparing model results...")
        try:
            comparison_results = compare_detailed_results(mamba_results, transformer_results)
            print("Model comparison completed successfully!")
        except Exception as e:
            print(f"Error in model comparison: {e}")
            import traceback
            traceback.print_exc()
            comparison_results = None
    else:
        print("\n[3/3] Skipping comparison as one or both models failed.")
        comparison_results = None
    
    # 计算总实验时间
    total_time = time.time() - experiment_start_time
    hours, remainder = divmod(total_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    
    print(f"\n{'='*30} EXPERIMENT COMPLETED {'='*30}")
    print(f"Total experiment time: {int(hours)}h {int(minutes)}m {seconds:.2f}s")
    
    # 保存总结果
    with open("results/experiment_summary.txt", "w") as f:
        f.write("MAMBA VS TRANSFORMER EXPERIMENT SUMMARY\n")
        f.write("="*50 + "\n\n")
        
        f.write(f"Experiment mode: {'TEST' if test_mode else 'FULL'}\n")
        f.write(f"Total experiment time: {int(hours)}h {int(minutes)}m {seconds:.2f}s\n\n")
        
        if mamba_results is not None:
            f.write("Mamba model: SUCCESS\n")
            f.write(f"  - Accuracy: {mamba_results.get('accuracy', 'N/A'):.4f}\n")
            f.write(f"  - F1 Score: {mamba_results.get('f1', 'N/A'):.4f}\n")
            f.write(f"  - Training time: {mamba_results.get('training_time', 'N/A'):.2f}s\n\n")
        else:
            f.write("Mamba model: FAILED\n\n")
        
        if transformer_results is not None:
            f.write("Transformer model: SUCCESS\n")
            f.write(f"  - Accuracy: {transformer_results.get('accuracy', 'N/A'):.4f}\n")
            f.write(f"  - F1 Score: {transformer_results.get('f1', 'N/A'):.4f}\n")
            f.write(f"  - Training time: {transformer_results.get('training_time', 'N/A'):.2f}s\n\n")
        else:
            f.write("Transformer model: FAILED\n\n")
        
        if comparison_results is not None:
            f.write("Models comparison: SUCCESS\n")
            f.write("  - See 'results/comparison/' directory for detailed reports and visualizations\n\n")
        else:
            f.write("Models comparison: FAILED or SKIPPED\n\n")
        
        f.write("\nResults directories:\n")
        f.write("  - Mamba results: ./results/mamba/\n")
        f.write("  - Transformer results: ./results/transformer/\n")
        f.write("  - Comparison results: ./results/comparison/\n")
    
    # 返回结果
    return {
        "mamba_results": mamba_results,
        "transformer_results": transformer_results,
        "comparison_results": comparison_results,
        "total_time": total_time
    }

In [9]:
# 用法示例 (取消注释以运行)
#experiment_results = run_complete_experiment(test_mode=True)  # 测试模式
experiment_results = run_complete_experiment(test_mode=False)  # 完整模式



[1/3] Training and evaluating Mamba model...

Running Mamba experiment  (FROM SCRATCH)
使用本地缓存的tokenizer
Mamba model parameters: 54,369,282


Step,Training Loss,Validation Loss,Accuracy,Memory Usage
250,0.7146,0.708354,0.490826,2115.070312
500,0.6802,0.606504,0.649083,2165.082031
750,0.4649,0.470075,0.783257,2157.386719
1000,0.3463,0.431795,0.793578,2157.519531
1250,0.2811,0.502349,0.779817,2159.078125
1500,0.2446,0.46117,0.788991,2159.347656
1750,0.2367,0.473358,0.809633,2159.394531
2000,0.2171,0.507201,0.806193,2159.441406
2250,0.1732,0.456599,0.830275,2160.222656
2500,0.1434,0.549011,0.795872,2160.300781


Step 250: Accuracy = 0.4908
New best accuracy: 0.4908 at step 250
已手动保存最佳模型到 ./results/mamba/full_scratch/best_model
Step 500: Recent average loss = 0.6974
Step 500: Accuracy = 0.6491
New best accuracy: 0.6491 at step 500
已手动保存最佳模型到 ./results/mamba/full_scratch/best_model
Step 750: Accuracy = 0.7833
New best accuracy: 0.7833 at step 750
已手动保存最佳模型到 ./results/mamba/full_scratch/best_model
Step 1000: Recent average loss = 0.5515
Step 1000: Accuracy = 0.7936
New best accuracy: 0.7936 at step 1000
已手动保存最佳模型到 ./results/mamba/full_scratch/best_model
Step 1250: Accuracy = 0.7798
Step 1500: Recent average loss = 0.4553
Step 1500: Accuracy = 0.7890
Step 1750: Accuracy = 0.8096
New best accuracy: 0.8096 at step 1750
已手动保存最佳模型到 ./results/mamba/full_scratch/best_model
Step 2000: Recent average loss = 0.3982
Step 2000: Accuracy = 0.8062
Step 2250: Accuracy = 0.8303
New best accuracy: 0.8303 at step 2250
已手动保存最佳模型到 ./results/mamba/full_scratch/best_model
Step 2500: Recent average loss = 0.3502
Step 2

Step 6312: Accuracy = 0.8303
Testing inference speed...
Testing long sequence performance...
Testing sequence length: 128
Testing sequence length: 256
Testing sequence length: 512
Testing sequence length: 1024
Testing sequence length: 2048
Testing sequence length: 4096
Testing sequence length: 8192
Testing sequence length: 16384
Mamba training and evaluation completed successfully!

[2/3] Training and evaluating Transformer model...

Running Transformer experiment  (FROM SCRATCH)
使用本地缓存的tokenizer
从头创建优化的BERT模型...
Transformer model parameters: 66,956,546


Step,Training Loss,Validation Loss,Accuracy,Memory Usage
250,0.9506,0.70235,0.489679,2294.445312
500,0.8382,0.707287,0.509174,2296.507812
750,0.8179,0.733286,0.509174,2296.507812
1000,0.7965,0.702022,0.509174,2296.5625
1250,0.7702,0.768985,0.544725,2296.660156
1500,0.7418,0.614473,0.692661,2296.667969
1750,0.6443,0.606957,0.748853,2296.667969
2000,0.5422,0.597611,0.754587,2296.667969
2250,0.447,0.557777,0.776376,2296.667969
2500,0.3962,0.647779,0.78211,2296.667969


Step 250: Accuracy = 0.4897
New best accuracy: 0.4897 at step 250
已手动保存最佳模型到 ./results/transformer/full_scratch/best_model
Step 500: Recent average loss = 0.8944
Step 500: Accuracy = 0.5092
New best accuracy: 0.5092 at step 500
已手动保存最佳模型到 ./results/transformer/full_scratch/best_model
Step 750: Accuracy = 0.5092
Step 1000: Recent average loss = 0.8508
Step 1000: Accuracy = 0.5092
Step 1250: Accuracy = 0.5447
New best accuracy: 0.5447 at step 1250
已手动保存最佳模型到 ./results/transformer/full_scratch/best_model
Step 1500: Recent average loss = 0.8192
Step 1500: Accuracy = 0.6927
New best accuracy: 0.6927 at step 1500
已手动保存最佳模型到 ./results/transformer/full_scratch/best_model
Step 1750: Accuracy = 0.7489
New best accuracy: 0.7489 at step 1750
已手动保存最佳模型到 ./results/transformer/full_scratch/best_model
Step 2000: Recent average loss = 0.7627
Step 2000: Accuracy = 0.7546
New best accuracy: 0.7546 at step 2000
已手动保存最佳模型到 ./results/transformer/full_scratch/best_model
Step 2250: Accuracy = 0.7764
New best 

Step 6312: Accuracy = 0.8142
Testing inference speed...
Testing long sequence performance...
Testing sequence length: 128
Testing sequence length: 256
Testing sequence length: 512
Testing sequence length: 1024
Testing sequence length: 2048
Testing sequence length: 4096
Testing sequence length: 8192
Testing sequence length: 16384
Transformer training and evaluation completed successfully!

[3/3] Comparing model results...

DETAILED MODEL COMPARISON

CLASSIFICATION PERFORMANCE
------------------------------
Accuracy:
  Mamba: 0.8303
  Transformer: 0.8142
  Difference: 0.0161
  Ratio (M/T): 1.0197

Precision:
  Mamba: 0.8318
  Transformer: 0.8190
  Difference: 0.0128
  Ratio (M/T): 1.0157

Recall:
  Mamba: 0.8356
  Transformer: 0.8153
  Difference: 0.0203
  Ratio (M/T): 1.0249

F1 Score:
  Mamba: 0.8337
  Transformer: 0.8172
  Difference: 0.0166
  Ratio (M/T): 1.0203

ROC AUC:
  Mamba: 0.9025
  Transformer: 0.8878
  Difference: 0.0147
  Ratio (M/T): 1.0166


EFFICIENCY METRICS
-----------