In [1]:
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertForSequenceClassification, get_linear_schedule_with_warmup, BertModel
from transformers.modeling_outputs import SequenceClassifierOutput
from torch.optim import AdamW
import numpy as np
import os
import random
import time
import datetime
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler
import matplotlib.pyplot as plt
plt.style.use('ggplot')
from matplotlib.font_manager import FontProperties
import matplotlib
matplotlib.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题

class TextDataset(Dataset):
    def __init__(self, sentences, labels, tokenizer, max_length=128):
        self.encodings = tokenizer(sentences, truncation=True, padding='max_length', max_length=max_length)
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

def set_seed(seed_value=42):
    """设置随机种子，确保结果可复现"""
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)

def load_data(file_path, tokenizer, max_length=128):
    """加载并处理数据，不使用pandas"""
    try:
        sentences = []
        labels = []

        with open(file_path, 'r', encoding='utf-8') as f:
            # 跳过标题行
            next(f)

            for line in f:
                line = line.strip()
                if not line:
                    continue

                # 假设CSV格式为: sentence,label
                parts = line.split(',')
                if len(parts) >= 2:
                    sentence = ','.join(parts[:-1])  # 处理句子中可能包含的逗号
                    label = int(parts[-1])
                    sentences.append(sentence)
                    labels.append(label)

        dataset = TextDataset(sentences, labels, tokenizer, max_length)
        return dataset
    except Exception as e:
        print(f"数据加载错误: {e}")
        return None

def create_weighted_sampler(dataset):
    """创建加权采样器以处理类别不平衡"""
    labels = np.array(dataset.labels)
    class_counts = np.bincount(labels)
    class_weights = 1. / torch.tensor(class_counts, dtype=torch.float)
    sample_weights = class_weights[labels]
    sampler = torch.utils.data.WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True
    )
    return sampler

def create_data_loader(dataset, batch_size, sampler_type='random'):
    """创建数据加载器"""
    if sampler_type == 'weighted':
        sampler = create_weighted_sampler(dataset)
        return DataLoader(dataset, sampler=sampler, batch_size=batch_size)
    elif sampler_type == 'random':
        sampler = torch.utils.data.RandomSampler(dataset)
    else:
        sampler = torch.utils.data.SequentialSampler(dataset)

    return DataLoader(dataset, sampler=sampler, batch_size=batch_size)

def calculate_metrics_multiclass_improved(preds, labels):
    """计算多分类任务的评估指标"""
    pred_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()

    # 计算每个类别的精确率、召回率和F1分数
    metrics = {'accuracy': np.sum(pred_flat == labels_flat) / len(labels_flat)}

    classes = np.unique(np.concatenate([pred_flat, labels_flat]))

    # 计算宏平均指标
    precision_sum = 0
    recall_sum = 0
    f1_sum = 0
    n_classes = len(classes)

    for cls in classes:
        tp = np.sum((pred_flat == cls) & (labels_flat == cls))
        fp = np.sum((pred_flat == cls) & (labels_flat != cls))
        fn = np.sum((pred_flat != cls) & (labels_flat == cls))

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

        precision_sum += precision
        recall_sum += recall
        f1_sum += f1

    metrics['precision'] = precision_sum / n_classes
    metrics['recall'] = recall_sum / n_classes
    metrics['f1'] = f1_sum / n_classes

    return metrics

# 增强的BERT分类模型
class EnhancedBertForSequenceClassification(nn.Module):
    def __init__(self, bert_model_name, num_labels):
        super(EnhancedBertForSequenceClassification, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.num_labels = num_labels  # 存储标签数量

        # 冻结BERT底层参数以专注于微调顶层
        for param in list(self.bert.parameters())[:-4*12]:  # 冻结除最后4层外的所有层
            param.requires_grad = False

        # 丰富的分类头部
        self.classifier = nn.Sequential(
            nn.Linear(self.bert.config.hidden_size, 768),
            nn.LayerNorm(768),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(768, 384),
            nn.LayerNorm(384),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(384, num_labels)
        )

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        pooled_output = outputs.pooler_output
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            # 计算每个类别的权重，反比于其频率
            if hasattr(self, 'class_weights') and self.class_weights is not None:
                loss_fct = nn.CrossEntropyLoss(weight=self.class_weights.to(labels.device))
            else:
                loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def set_class_weights(self, class_weights):
        """设置类别权重用于损失函数"""
        self.class_weights = class_weights

def train_model(model, train_dataloader, val_dataloader, optimizer, scheduler, device, epochs, save_path):
    """训练模型"""
    total_t0 = time.time()
    best_val_loss = float('inf')
    best_accuracy = 0.0

    if not os.path.exists(save_path):
        os.makedirs(save_path)

    # 添加早停变量
    patience = 3
    early_stop_counter = 0

    # 创建梯度缩放器用于混合精度训练
    scaler = GradScaler()

    # 记录训练过程中的指标
    history = {
        'train_loss': [],
        'val_loss': [],
        'accuracy': [],
        'precision': [],
        'recall': [],
        'f1': []
    }

    for epoch_i in range(0, epochs):
        print(f'======== Epoch {epoch_i + 1} / {epochs} ========')
        print('Training...')

        t0 = time.time()
        total_train_loss = 0

        model.train()

        for step, batch in enumerate(train_dataloader):
            if step % 40 == 0 and not step == 0:
                elapsed = format_time(time.time() - t0)
                print(f'  Batch {step} of {len(train_dataloader)}. Elapsed: {elapsed}')

            b_input_ids = batch['input_ids'].to(device)
            b_input_mask = batch['attention_mask'].to(device)
            b_labels = batch['labels'].to(device)

            model.zero_grad()

            # 使用混合精度训练
            with autocast():
                outputs = model(b_input_ids,
                                token_type_ids=None,
                                attention_mask=b_input_mask,
                                labels=b_labels)
                loss = outputs.loss

            total_train_loss += loss.item()

            # 缩放梯度并反向传播
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

        avg_train_loss = total_train_loss / len(train_dataloader)
        training_time = format_time(time.time() - t0)

        print(f"  Average training loss: {avg_train_loss:.4f}")
        print(f"  Training epoch took: {training_time}")

        print("\nRunning Validation...")
        t0 = time.time()

        model.eval()

        total_eval_loss = 0
        all_preds = []
        all_labels = []

        for batch in val_dataloader:
            b_input_ids = batch['input_ids'].to(device)
            b_input_mask = batch['attention_mask'].to(device)
            b_labels = batch['labels'].to(device)

            with torch.no_grad():
                outputs = model(b_input_ids,
                                token_type_ids=None,
                                attention_mask=b_input_mask,
                                labels=b_labels)

            loss = outputs.loss
            total_eval_loss += loss.item()

            logits = outputs.logits
            logits = logits.detach().cpu().numpy()
            label_ids = b_labels.to('cpu').numpy()

            all_preds.append(logits)
            all_labels.append(label_ids)

        all_preds = np.concatenate(all_preds, axis=0)
        all_labels = np.concatenate(all_labels, axis=0)

        avg_val_loss = total_eval_loss / len(val_dataloader)
        metrics = calculate_metrics_multiclass_improved(all_preds, all_labels)
        validation_time = format_time(time.time() - t0)

        # 记录当前epoch的指标
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['accuracy'].append(metrics['accuracy'])
        history['precision'].append(metrics['precision'])
        history['recall'].append(metrics['recall'])
        history['f1'].append(metrics['f1'])

        print(f"  Accuracy: {metrics['accuracy']:.4f}")
        print(f"  Validation Loss: {avg_val_loss:.4f}")
        print(f"  Precision: {metrics['precision']:.4f}")
        print(f"  Recall: {metrics['recall']:.4f}")
        print(f"  F1-Score: {metrics['f1']:.4f}")
        print(f"  Validation took: {validation_time}")

        # 验证后的早停检查 - 根据准确率和损失共同判断
        current_accuracy = metrics['accuracy']
        if current_accuracy > best_accuracy:
            best_accuracy = current_accuracy
            early_stop_counter = 0
            model_path = os.path.join(save_path, 'best_model_acc.pt')
            torch.save(model.state_dict(), model_path)
            print(f"  Best accuracy model saved at: {model_path}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            if abs(best_accuracy - current_accuracy) <= 0.01:  # 如果精度相近，优先选择低损失模型
                early_stop_counter = 0
                model_path = os.path.join(save_path, 'best_model.pt')
                torch.save(model.state_dict(), model_path)
                print(f"  Best loss model saved at: {model_path}")
        else:
            early_stop_counter += 1
            if early_stop_counter >= patience:
                print(f"Early stopping after {epoch_i + 1} epochs")
                break

    print(f"\nTraining complete! Total training took {format_time(time.time()-total_t0)}")
    print(f"Best validation accuracy: {best_accuracy:.4f}")

    # 保存并绘制训练历史
    plot_training_history(history, save_path)

    return best_val_loss, best_accuracy, history

def format_time(elapsed):
    """将时间格式化为 hh:mm:ss"""
    elapsed_rounded = int(round((elapsed)))
    return str(datetime.timedelta(seconds=elapsed_rounded))

def inspect_dataset(file_path):
    """检查数据集中的标签分布"""
    labels = []
    with open(file_path, 'r', encoding='utf-8') as f:
        next(f)  # 跳过标题行
        for line in f:
            parts = line.strip().split(',')
            if len(parts) >= 2:
                label = int(parts[-1])
                labels.append(label)

    unique_labels = np.unique(labels)
    counts = {label: labels.count(label) for label in unique_labels}

    print(f"文件 {file_path} 中的标签分布:")
    print(f"唯一标签值: {unique_labels}")
    print(f"标签计数: {counts}")

    return unique_labels, counts

def calculate_class_weights(counts, num_labels):
    """计算类别权重"""
    weights = torch.zeros(num_labels)
    total_samples = sum(counts.values())

    for label, count in counts.items():
        weights[label] = total_samples / (count * num_labels)

    return weights

def plot_training_history(history, save_path):
    """Plot training metrics curves"""
    epochs = range(1, len(history['train_loss']) + 1)

    # Create directory for plots if it doesn't exist
    plots_dir = os.path.join(save_path, 'plots')
    if not os.path.exists(plots_dir):
        os.makedirs(plots_dir)

    # 1. Plot loss curves
    plt.figure(figsize=(10, 6))
    plt.plot(epochs, history['train_loss'], 'b-', label='Training Loss')
    plt.plot(epochs, history['val_loss'], 'r-', label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(plots_dir, 'loss_curve.pdf'), dpi=300, bbox_inches='tight')
    plt.close()

    # 2. Plot accuracy curve
    plt.figure(figsize=(10, 6))
    plt.plot(epochs, history['accuracy'], 'g-', label='Accuracy')
    plt.title('Validation Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(plots_dir, 'accuracy_curve.pdf'), dpi=300, bbox_inches='tight')
    plt.close()

    # 3. Plot precision, recall, and F1 score curves
    plt.figure(figsize=(12, 8))
    plt.plot(epochs, history['precision'], 'b-', label='Precision')
    plt.plot(epochs, history['recall'], 'r-', label='Recall')
    plt.plot(epochs, history['f1'], 'g-', label='F1-Score')
    plt.title('Validation Metrics')
    plt.xlabel('Epochs')
    plt.ylabel('Score')
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(plots_dir, 'metrics_curve.pdf'), dpi=300, bbox_inches='tight')
    plt.close()

    # 4. Plot all metrics in one figure
    plt.figure(figsize=(15, 10))
    plt.subplot(2, 1, 1)
    plt.plot(epochs, history['train_loss'], 'b-', label='Training Loss')
    plt.plot(epochs, history['val_loss'], 'r-', label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    plt.subplot(2, 1, 2)
    plt.plot(epochs, history['accuracy'], 'm-', label='Accuracy')
    plt.plot(epochs, history['precision'], 'c-', label='Precision')
    plt.plot(epochs, history['recall'], 'y-', label='Recall')
    plt.plot(epochs, history['f1'], 'g-', label='F1-Score')
    plt.title('Validation Metrics')
    plt.xlabel('Epochs')
    plt.ylabel('Score')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.savefig(os.path.join(plots_dir, 'all_metrics.pdf'), dpi=300, bbox_inches='tight')
    plt.close()

    print(f"训练历史图表已保存到: {plots_dir}")

    # 保存历史数据为CSV文件
    import csv
    csv_path = os.path.join(save_path, 'training_history.csv')
    with open(csv_path, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['epoch', 'train_loss', 'val_loss', 'accuracy', 'precision', 'recall', 'f1'])
        for i in range(len(epochs)):
            writer.writerow([epochs[i],
                            history['train_loss'][i],
                            history['val_loss'][i],
                            history['accuracy'][i],
                            history['precision'][i],
                            history['recall'][i],
                            history['f1'][i]])
    print(f"训练历史数据已保存到: {csv_path}")

def main():
    # 设置参数
    SEED = 42
    BATCH_SIZE = 32  # 增大批量大小
    LEARNING_RATE = 1e-5  # 使用更合适的学习率
    EPSILON = 1e-8
    EPOCHS = 15  # 增加训练轮次以便更好地学习
    MAX_LENGTH = 128
    SAVE_PATH = 'model_output'
    TRAIN_FILE = 'training.csv'
    VAL_FILE = 'validation.csv'
    USE_CLASS_WEIGHTS = True  # 使用类别权重
    USE_WEIGHTED_SAMPLER = True  # 使用加权采样器

    set_seed(SEED)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    print("Loading data...")
    train_dataset = load_data(TRAIN_FILE, tokenizer, MAX_LENGTH)
    val_dataset = load_data(VAL_FILE, tokenizer, MAX_LENGTH)

    if train_dataset is None or val_dataset is None:
        print("数据加载失败，程序退出")
        return

    train_labels, train_counts = inspect_dataset(TRAIN_FILE)
    val_labels, val_counts = inspect_dataset(VAL_FILE)

    # 确保模型配置与数据匹配
    num_labels = max(max(train_labels), max(val_labels)) + 1
    print(f"检测到的最大标签值: {num_labels-1}，设置num_labels={num_labels}")

    # 使用增强的BERT分类模型
    model = EnhancedBertForSequenceClassification('bert-base-uncased', num_labels=num_labels)

    # 计算并设置类别权重
    if USE_CLASS_WEIGHTS:
        class_weights = calculate_class_weights(train_counts, num_labels)
        model.set_class_weights(class_weights)
        print(f"使用类别权重: {class_weights}")

    model.to(device)

    # 使用加权采样器处理类别不平衡
    if USE_WEIGHTED_SAMPLER:
        train_dataloader = create_data_loader(train_dataset, BATCH_SIZE, 'weighted')
        print("使用加权采样器处理类别不平衡")
    else:
        train_dataloader = create_data_loader(train_dataset, BATCH_SIZE, 'random')

    val_dataloader = create_data_loader(val_dataset, BATCH_SIZE, 'sequential')

    # 使用较低的初始学习率，增加权重衰减
    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, eps=EPSILON, weight_decay=0.01)

    total_steps = len(train_dataloader) * EPOCHS

    # 修改学习率调度策略，添加预热步骤
    warmup_steps = int(0.1 * total_steps)  # 10%的预热步骤
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )

    print("Starting training...")
    best_val_loss, best_accuracy, history = train_model(model, train_dataloader, val_dataloader, optimizer, scheduler,
                                device, EPOCHS, SAVE_PATH)

    print(f"Training completed with best validation loss: {best_val_loss:.4f}")
    print(f"Best validation accuracy: {best_accuracy:.4f}")

    # 如果需要单独调用可视化函数，可以在这里添加
    # plot_training_history(history, SAVE_PATH)

if __name__ == "__main__":
    main()

Using device: cuda


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Loading data...
文件 training.csv 中的标签分布:
唯一标签值: [0 1 2 3 4 5]
标签计数: {np.int64(0): 4665, np.int64(1): 5362, np.int64(2): 1304, np.int64(3): 2159, np.int64(4): 1937, np.int64(5): 572}
文件 validation.csv 中的标签分布:
唯一标签值: [0 1 2 3 4 5]
标签计数: {np.int64(0): 549, np.int64(1): 704, np.int64(2): 178, np.int64(3): 275, np.int64(4): 212, np.int64(5): 81}
检测到的最大标签值: 5，设置num_labels=6


Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

使用类别权重: tensor([0.5716, 0.4973, 2.0449, 1.2351, 1.3766, 4.6617])
使用加权采样器处理类别不平衡
Starting training...
Training...


  scaler = GradScaler()
  with autocast():


  Batch 40 of 500. Elapsed: 0:00:04
  Batch 80 of 500. Elapsed: 0:00:07
  Batch 120 of 500. Elapsed: 0:00:10
  Batch 160 of 500. Elapsed: 0:00:13
  Batch 200 of 500. Elapsed: 0:00:16
  Batch 240 of 500. Elapsed: 0:00:19
  Batch 280 of 500. Elapsed: 0:00:22
  Batch 320 of 500. Elapsed: 0:00:25
  Batch 360 of 500. Elapsed: 0:00:28
  Batch 400 of 500. Elapsed: 0:00:31
  Batch 440 of 500. Elapsed: 0:00:34
  Batch 480 of 500. Elapsed: 0:00:38
  Average training loss: 1.4729
  Training epoch took: 0:00:39

Running Validation...
  Accuracy: 0.2101
  Validation Loss: 1.5993
  Precision: 0.1658
  Recall: 0.4165
  F1-Score: 0.2106
  Validation took: 0:00:12
  Best accuracy model saved at: model_output/best_model_acc.pt
  Best loss model saved at: model_output/best_model.pt
Training...
  Batch 40 of 500. Elapsed: 0:00:03
  Batch 80 of 500. Elapsed: 0:00:06
  Batch 120 of 500. Elapsed: 0:00:10
  Batch 160 of 500. Elapsed: 0:00:13
  Batch 200 of 500. Elapsed: 0:00:16
  Batch 240 of 500. Elapsed: 0:

In [3]:
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel
from transformers.modeling_outputs import SequenceClassifierOutput
import numpy as np
import os
import torch.nn as nn

# 复用与训练相同的数据集类
class TextDataset(Dataset):
    def __init__(self, sentences, labels, tokenizer, max_length=128):
        self.encodings = tokenizer(sentences, truncation=True, padding='max_length', max_length=max_length)
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

# 复用与训练相同的增强BERT模型
class EnhancedBertForSequenceClassification(nn.Module):
    def __init__(self, bert_model_name, num_labels):
        super(EnhancedBertForSequenceClassification, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.num_labels = num_labels  # 存储标签数量

        # 丰富的分类头部
        self.classifier = nn.Sequential(
            nn.Linear(self.bert.config.hidden_size, 768),
            nn.LayerNorm(768),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(768, 384),
            nn.LayerNorm(384),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(384, num_labels)
        )

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        pooled_output = outputs.pooler_output
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def set_class_weights(self, class_weights):
        """设置类别权重用于损失函数"""
        self.class_weights = class_weights

def load_data(file_path, tokenizer, max_length=128):
    """加载并处理数据，不使用pandas"""
    try:
        sentences = []
        labels = []

        with open(file_path, 'r', encoding='utf-8') as f:
            # 跳过标题行
            next(f)

            for line in f:
                line = line.strip()
                if not line:
                    continue

                # 假设CSV格式为: sentence,label
                parts = line.split(',')
                if len(parts) >= 2:
                    sentence = ','.join(parts[:-1])  # 处理句子中可能包含的逗号
                    label = int(parts[-1])
                    sentences.append(sentence)
                    labels.append(label)

        dataset = TextDataset(sentences, labels, tokenizer, max_length)
        return dataset
    except Exception as e:
        print(f"数据加载错误: {e}")
        return None

def create_data_loader(dataset, batch_size, sampler_type='sequential'):
    """创建数据加载器"""
    sampler = torch.utils.data.SequentialSampler(dataset)
    return DataLoader(dataset, sampler=sampler, batch_size=batch_size)

def calculate_metrics_multiclass_improved(preds, labels):
    """计算多分类任务的评估指标"""
    pred_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()

    # 计算每个类别的精确率、召回率和F1分数
    metrics = {'accuracy': np.sum(pred_flat == labels_flat) / len(labels_flat)}

    classes = np.unique(np.concatenate([pred_flat, labels_flat]))

    # 计算宏平均指标
    precision_sum = 0
    recall_sum = 0
    f1_sum = 0
    n_classes = len(classes)

    for cls in classes:
        tp = np.sum((pred_flat == cls) & (labels_flat == cls))
        fp = np.sum((pred_flat == cls) & (labels_flat != cls))
        fn = np.sum((pred_flat != cls) & (labels_flat == cls))

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

        precision_sum += precision
        recall_sum += recall
        f1_sum += f1

    metrics['precision'] = precision_sum / n_classes
    metrics['recall'] = recall_sum / n_classes
    metrics['f1'] = f1_sum / n_classes

    return metrics

def evaluate_model(model, dataloader, device):
    """评估模型性能"""
    model.eval()

    all_preds = []
    all_labels = []

    for batch in dataloader:
        b_input_ids = batch['input_ids'].to(device)
        b_input_mask = batch['attention_mask'].to(device)
        b_labels = batch['labels'].to(device)

        with torch.no_grad():
            outputs = model(b_input_ids,
                            token_type_ids=None,
                            attention_mask=b_input_mask)

        logits = outputs.logits
        logits = logits.detach().cpu().numpy()
        label_ids = b_labels.to('cpu').numpy()

        all_preds.append(logits)
        all_labels.append(label_ids)

    all_preds = np.concatenate(all_preds, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)

    return calculate_metrics_multiclass_improved(all_preds, all_labels)

def predict_text(model, tokenizer, text, device, max_length=128):
    """预测单个文本的情感"""
    model.eval()

    inputs = tokenizer(text, truncation=True, padding='max_length',
                      max_length=max_length, return_tensors='pt')

    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits
    probs = torch.nn.functional.softmax(logits, dim=1)
    confidence, predicted_class = torch.max(probs, dim=1)

    predicted_class = predicted_class.item()
    confidence = confidence.item()
    probs = probs.cpu().numpy()[0]

    return {
        'predicted_class': predicted_class,
        'confidence': confidence,
        'probabilities': probs
    }

def inspect_dataset(file_path):
    """检查数据集中的标签分布"""
    labels = []
    with open(file_path, 'r', encoding='utf-8') as f:
        next(f)  # 跳过标题行
        for line in f:
            parts = line.strip().split(',')
            if len(parts) >= 2:
                label = int(parts[-1])
                labels.append(label)

    unique_labels = np.unique(labels)
    counts = {label: labels.count(label) for label in unique_labels}

    print(f"文件 {file_path} 中的标签分布:")
    print(f"唯一标签值: {unique_labels}")
    print(f"标签计数: {counts}")

    return unique_labels, counts

def main():
    # 设置参数
    MODEL_PATH = 'model_output/best_model_acc.pt'  # 使用最高准确率的模型
    TEST_FILE = 'test.csv'
    MAX_LENGTH = 128
    BATCH_SIZE = 32

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # 加载数据集并检查标签分布
    print("Loading test data...")
    test_dataset = load_data(TEST_FILE, tokenizer, MAX_LENGTH)

    if test_dataset is None:
        print("测试数据加载失败，程序退出")
        return

    test_labels, test_counts = inspect_dataset(TEST_FILE)

    # 确定标签数量
    num_labels = max(test_labels) + 1
    print(f"检测到的最大标签值: {num_labels-1}，设置num_labels={num_labels}")

    # 创建与训练时相同的模型架构
    model = EnhancedBertForSequenceClassification('bert-base-uncased', num_labels=num_labels)

    # 加载训练好的模型权重
    if os.path.exists(MODEL_PATH):
        model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
        print(f"模型权重已加载: {MODEL_PATH}")
    else:
        print(f"模型文件不存在: {MODEL_PATH}")
        return

    model.to(device)

    print("Evaluating model...")
    test_dataloader = create_data_loader(test_dataset, batch_size=BATCH_SIZE)
    metrics = evaluate_model(model, test_dataloader, device)

    print("\nEvaluation Results:")
    print(f"Accuracy: {metrics['accuracy']:.4f}")
    print(f"Precision: {metrics['precision']:.4f}")
    print(f"Recall: {metrics['recall']:.4f}")
    print(f"F1-Score: {metrics['f1']:.4f}")

    # 示例测试文本
    print("\nExample Predictions:")
    test_texts = [
        "i started out feeling discouraged this morning",
        "i am feeling better though i dont sound it",
        "i feel very mislead by someone that i really really thought i knew and liked very much so",
        "i feel petty all of a sudden",
        "i don t want them to feel so pressured",
        "i love and captured an atmospheric feeling in their landscapes that really impressed me"
    ]

    class_mappings = {
        0: "悲伤",
        1: "快乐",
        2: "爱",
        3: "愤怒",
        4: "恐惧",
        5: "惊讶"
    }

    for text in test_texts:
        result = predict_text(model, tokenizer, text, device, MAX_LENGTH)
        predicted_class = result['predicted_class']
        sentiment = class_mappings.get(predicted_class, f"类别{predicted_class}")

        print(f"\nText: {text}")
        print(f"Predicted Class: {predicted_class} ({sentiment})")
        print(f"Confidence: {result['confidence']:.4f}")

        # 打印各类别概率
        print("Class probabilities:")
        for i, prob in enumerate(result['probabilities']):
            sent = class_mappings.get(i, f"类别{i}")
            print(f"  {sent}: {prob:.4f}")

if __name__ == "__main__":
    main()

Using device: cpu
Loading test data...
文件 test.csv 中的标签分布:
唯一标签值: [0 1 2 3 4 5]
标签计数: {np.int64(0): 580, np.int64(1): 695, np.int64(2): 159, np.int64(3): 275, np.int64(4): 224, np.int64(5): 66}
检测到的最大标签值: 5，设置num_labels=6
模型权重已加载: model_output/best_model_acc.pt
Evaluating model...

Evaluation Results:
Accuracy: 0.8599
Precision: 0.7968
Recall: 0.8826
F1-Score: 0.8277

Example Predictions:

Text: i started out feeling discouraged this morning
Predicted Class: 0 (悲伤)
Confidence: 0.9874
Class probabilities:
  悲伤: 0.9874
  快乐: 0.0038
  爱: 0.0006
  愤怒: 0.0039
  恐惧: 0.0039
  惊讶: 0.0004

Text: i am feeling better though i dont sound it
Predicted Class: 1 (快乐)
Confidence: 0.9921
Class probabilities:
  悲伤: 0.0044
  快乐: 0.9921
  爱: 0.0009
  愤怒: 0.0007
  恐惧: 0.0010
  惊讶: 0.0010

Text: i feel very mislead by someone that i really really thought i knew and liked very much so
Predicted Class: 0 (悲伤)
Confidence: 0.7356
Class probabilities:
  悲伤: 0.7356
  快乐: 0.0107
  爱: 0.1675
  愤怒: 0.0798
  恐惧: 0.00