In [40]:
from sklearn.metrics import roc_auc_score
from utils import *
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
from matplotlib import rcParams
# Global Matplotlib defaults for every figure in this notebook:
# Times New Roman text and 300-dpi figures.
plt.rcParams.update({'font.family': 'Times New Roman', 'figure.dpi': 300})
# ----------------------
# 训练函数: 分类模型
# ----------------------
def _predict_batches(model, loader, device, threshold, criterion=None):
    """Run ``model`` over every batch of ``loader`` without gradient tracking.

    Each batch is a mapping with ``'features'`` and ``'species'`` tensors (the
    keys used by every loader in this file).

    Args:
        model: network returning raw logits.
        loader: iterable of batches (must support ``len``).
        device: device the inputs and targets are moved to.
        threshold: probability cut-off; ``sigmoid(logits) > threshold`` -> 1.
        criterion: optional loss; when given, its per-batch values are averaged.

    Returns:
        ``(mean_loss, preds, targets)``. ``mean_loss`` is the mean of the
        per-batch losses (``None`` without a criterion). ``preds`` and ``targets``
        are numpy arrays stacked along the sample axis.
    """
    total_loss = 0.0
    pred_batches = []
    target_batches = []
    with torch.no_grad():
        for batch in loader:
            inputs = batch['features'].to(device)
            targets = batch['species'].to(device)
            logits = model(inputs)
            if criterion is not None:
                total_loss += criterion(logits, targets).item()
            probs = torch.sigmoid(logits)
            pred_batches.append((probs > threshold).float().cpu().numpy())
            target_batches.append(targets.cpu().numpy())
    mean_loss = total_loss / len(loader) if criterion is not None else None
    return mean_loss, np.concatenate(pred_batches), np.concatenate(target_batches)


def train_species_model(model, train_loader, val_loader, device, class_weights=None,
                        epochs=300, patience=20, log_interval=10, threshold=0.3):
    """Train the multi-label algae species classifier with early stopping.

    The loss is ``FocalLoss(gamma=2.0, alpha=pos_weights)``. ``FocalLoss`` is not
    defined in this file; presumably it comes from ``utils``. ``pos_weights`` is
    either ``class_weights`` or, when that is ``None``, the per-class ratio
    negatives / positives counted over ``train_loader``.

    Optimisation uses AdamW (lr=1e-3, weight_decay=1e-6), gradient-norm clipping
    at 1.0 and ``ReduceLROnPlateau`` on the validation loss. The weights with the
    lowest validation loss are restored before returning.

    Args:
        model: network mapping ``batch['features']`` to per-class logits.
        train_loader: loader of dicts with ``'features'`` and ``'species'`` (targets).
        val_loader: same format; drives the LR schedule and early stopping.
        device: device used for training.
        class_weights: optional per-class positive weights (array-like or tensor).
        epochs: maximum number of epochs.
        patience: epochs without validation-loss improvement before stopping.
        log_interval: print progress every ``log_interval`` epochs.
        threshold: probability cut-off used only for the micro-F1 bookkeeping.

    Returns:
        ``(model, train_losses, val_losses, val_f1s, train_f1s)``. These are
        per-epoch histories, and the F1 values are micro-averaged.

    Raises:
        ValueError: if either loader yields no batches.
    """
    import copy  # for deepcopy of the best state_dict; local so the cell is self-contained

    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must each yield at least one batch")

    # Per-class weights: negatives / positives (the epsilon guards classes with no positives).
    if class_weights is None:
        positive_counts = None
        for batch in train_loader:
            batch_counts = batch['species'].sum(dim=0).float()
            positive_counts = batch_counts if positive_counts is None else positive_counts + batch_counts
        num_samples = len(train_loader.dataset)
        pos_weights = ((num_samples - positive_counts) / (positive_counts + 1e-5)).to(device)
    else:
        # float32 to match the computed branch; torch.tensor() on a float64 numpy
        # array (as produced by the callers in this file) would give float64 weights.
        pos_weights = torch.as_tensor(class_weights, dtype=torch.float32, device=device)

    # Focal loss down-weights easy examples to cope with the strong class imbalance.
    # nn.BCEWithLogitsLoss(pos_weight=pos_weights) is the simpler alternative; it was
    # previously constructed here and immediately overwritten, so it never took effect.
    criterion = FocalLoss(gamma=2.0, alpha=pos_weights)

    optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-6)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=10)

    # Model selection tracks the lowest validation loss.
    best_val_loss = float('inf')
    best_model_state = None
    patience_counter = 0
    train_losses = []
    train_f1s = []
    val_losses = []
    val_f1s = []

    for epoch in range(epochs):
        # --- training pass ---
        model.train()
        epoch_loss = 0.0
        for batch in train_loader:
            inputs = batch['features'].to(device)
            targets = batch['species'].to(device)

            optimizer.zero_grad()
            logits = model(inputs)
            loss = criterion(logits, targets)
            loss.backward()
            # Clip gradients to guard against exploding updates.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            epoch_loss += loss.item()

        avg_train_loss = epoch_loss / len(train_loader)
        train_losses.append(avg_train_loss)

        # --- evaluation passes (eval mode, no gradients) ---
        model.eval()
        avg_val_loss, val_preds, val_targets = _predict_batches(
            model, val_loader, device, threshold, criterion)
        val_losses.append(avg_val_loss)
        # Micro-F1 is recorded for reporting only; model selection uses the loss.
        val_f1 = f1_score(val_targets, val_preds, average='micro')
        val_f1s.append(val_f1)

        _, train_preds, train_targets = _predict_batches(model, train_loader, device, threshold)
        train_f1s.append(f1_score(train_targets, train_preds, average='micro'))

        scheduler.step(avg_val_loss)

        if (epoch + 1) % log_interval == 0:
            print(f"Epoch {epoch+1}/{epochs} | Train Loss: {avg_train_loss:.4f} | "
                  f"Val Loss: {avg_val_loss:.4f} | Val F1: {val_f1:.4f}")

        # Early stopping on the validation loss.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_model_state = copy.deepcopy(model.state_dict())
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"早停触发于epoch {epoch+1}")
                break

    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    else:
        # Happens when epochs == 0 or the validation loss was NaN on every epoch.
        print("Warning: validation loss never improved; keeping the last-epoch weights.")

    return model, train_losses, val_losses, val_f1s, train_f1s

def evaluate_species_model(model, test_loader, device, mlb, threshold=0.3, find_optimal_threshold=False):
    """Evaluate the algae species classifier and print aggregate and per-class metrics.

    Args:
        model: network returning per-class logits for ``batch['features']``.
        test_loader: loader yielding dicts with ``'features'`` and ``'species'``.
        device: device the inputs are moved to.
        mlb: fitted multi-label binarizer; ``mlb.classes_`` names the columns.
        threshold: probability cut-off applied to every class.
        find_optimal_threshold: if True, replace ``threshold`` with the value from
            the grid ``np.arange(0.1, 0.9, 0.05)`` that maximises macro-F1 on
            ``test_loader`` itself. NOTE: tuning on the data being reported makes
            the returned scores optimistic.

    Returns:
        ``(micro_f1, macro_f1, micro_auc, macro_auc, all_probs, all_targets,
        class_f1, class_auc, sorted_classes)``:
        - ``class_f1`` / ``class_auc`` map class name -> score. The score is NaN
          when the class has no positives in the data.
        - ``sorted_classes`` is a list of ``(name, f1, auc)`` sorted by F1
          descending, with NaN F1 entries last.
        - An aggregate AUC that sklearn cannot compute is reported as 0.0.
    """
    model.eval()
    prob_batches = []
    target_batches = []
    with torch.no_grad():
        for batch in test_loader:
            inputs = batch['features'].to(device)
            logits = model(inputs)
            # torch.sigmoid is numerically stable; a hand-written 1/(1+exp(-x)) in
            # numpy overflows (with a RuntimeWarning) for large negative logits.
            prob_batches.append(torch.sigmoid(logits).cpu().numpy())
            target_batches.append(batch['species'].cpu().numpy())

    all_probs = np.concatenate(prob_batches)
    all_targets = np.concatenate(target_batches)

    if find_optimal_threshold:
        best_f1 = 0
        best_threshold = threshold
        # Round the grid so candidates are exactly 0.10, 0.15, ... rather than values
        # like 0.40000000000000013 produced by np.arange's floating-point steps.
        for t in np.round(np.arange(0.1, 0.9, 0.05), 2):
            preds = (all_probs > t).astype(int)
            # zero_division=0 gives the same value as the default ("warn") without the warnings.
            f1 = f1_score(all_targets, preds, average='macro', zero_division=0)
            if f1 > best_f1:
                best_f1 = f1
                best_threshold = float(t)
        print(f"找到最佳阈值: {best_threshold:.2f}, F1: {best_f1:.4f}")
        threshold = best_threshold

    all_preds = (all_probs > threshold).astype(int)

    micro_f1 = f1_score(all_targets, all_preds, average='micro')
    macro_f1 = f1_score(all_targets, all_preds, average='macro')

    def _aggregate_auc(average):
        # Computed separately so a failure of one average (e.g. macro, when some
        # class has a single label value) does not also zero the other.
        try:
            return roc_auc_score(all_targets, all_probs, average=average)
        except ValueError as e:
            print(f"AUC计算异常: {e}")
            return 0.0

    micro_auc = _aggregate_auc('micro')
    macro_auc = _aggregate_auc('macro')

    print(f"\n藻类种类分类模型评估结果 (阈值 = {threshold:.2f}):")
    print(f"Micro-F1: {micro_f1:.4f} | Macro-F1: {macro_f1:.4f}")
    print(f"Micro-AUC: {micro_auc:.4f} | Macro-AUC: {macro_auc:.4f}")

    # Per-class metrics; only classes with at least one positive sample are scored.
    class_names = mlb.classes_
    class_f1 = {}
    class_auc = {}
    pos_counts = np.sum(all_targets, axis=0)

    for i, class_name in enumerate(class_names):
        if pos_counts[i] > 0:
            class_f1[class_name] = f1_score(all_targets[:, i], all_preds[:, i])
            try:
                class_auc[class_name] = roc_auc_score(all_targets[:, i], all_probs[:, i])
            except ValueError:
                # e.g. every sample is positive for this class
                class_auc[class_name] = np.nan
        else:
            class_f1[class_name] = np.nan
            class_auc[class_name] = np.nan

    # NaN keys would make the sort order arbitrary; rank them last instead.
    sorted_classes = sorted(
        ((name, f1_value, class_auc[name]) for name, f1_value in class_f1.items()),
        key=lambda item: -np.inf if np.isnan(item[1]) else item[1],
        reverse=True,
    )

    print("\n各藻类种类的F1分数和AUC:")
    for name, f1_value, auc_value in sorted_classes:
        print(f"{name}: F1={f1_value:.4f} | AUC={auc_value:.4f}")

    return micro_f1, macro_f1, micro_auc, macro_auc, all_probs, all_targets, class_f1, class_auc, sorted_classes

import matplotlib.patheffects as pe
# ----------------------
# 阈值优化函数
# ----------------------
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib
# Use Times New Roman as the global font for the threshold-analysis figures.
matplotlib.rcParams.update({'font.family': 'Times New Roman'})

def _plot_f1_threshold_panel(ax, probs, targets, class_names):
    """Draw one F1-vs-threshold curve per class on ``ax``; return the best threshold per class.

    Thresholds are 100 evenly spaced values in [0, 1]. A sample counts as positive
    when ``prob > threshold``. Classes with no positive targets are skipped: they
    get no curve and no entry in the result.

    Returns:
        dict mapping class name -> {'threshold': best threshold, 'f1_score': its F1}.
    """
    n_classes = len(class_names)
    palette = plt.cm.viridis(np.linspace(0.1, 0.9, n_classes))
    line_styles = ['-', '--', '-.', ':'] * (n_classes // 4 + 1)
    thresholds = np.linspace(0, 1, 100)
    optimal_thresholds = {}

    for i, class_name in enumerate(class_names):
        y_true = targets[:, i]
        y_prob = probs[:, i]
        if np.sum(y_true) < 1:
            continue

        f1_scores = [f1_score(y_true, (y_prob > t).astype(int), zero_division=0)
                     for t in thresholds]
        best_idx = np.argmax(f1_scores)
        best_thresh = thresholds[best_idx]
        best_f1 = f1_scores[best_idx]
        optimal_thresholds[class_name] = {'threshold': best_thresh, 'f1_score': best_f1}

        # Curve with a white halo so overlapping lines stay readable.
        ax.plot(thresholds, f1_scores,
                color=palette[i],
                linestyle=line_styles[i],
                lw=2.5,
                alpha=0.9,
                label=f'{class_name}',
                path_effects=[pe.Stroke(linewidth=4, foreground='white'), pe.Normal()])
        ax.fill_between(thresholds, f1_scores, alpha=0.15, color=palette[i])
        # Label the best point with an arrow.
        ax.annotate(f'{best_f1:.2f}',
                    xy=(best_thresh, best_f1),
                    xytext=(best_thresh + 0.05, best_f1 - 0.1),
                    arrowprops=dict(arrowstyle="->",
                                    color=palette[i],
                                    connectionstyle="arc3,rad=-0.2"),
                    fontsize=12,
                    bbox=dict(boxstyle="round",
                              facecolor='white',
                              edgecolor=palette[i],
                              alpha=0.8))

    ax.set_xlabel('Classification Threshold', labelpad=12)
    ax.set_ylabel('F1 Score', labelpad=12)
    ax.set_xticks(np.arange(0, 1.1, 0.1))
    ax.set_yticks(np.arange(0, 1.1, 0.1))
    ax.set_ylim(0, 1.05)
    ax.set_xlim(0.0, 0.9)

    # Horizontal reference lines at F1 = 0.7 and 0.9.
    for level in (0.7, 0.9):
        ax.axhline(level, color='#34495e', linestyle=':', lw=1.5, alpha=0.6, zorder=0)
        ax.text(0.01, level + 0.02,
                f'Benchmark (F1={level})',
                fontsize=18,
                color='#34495e',
                va='bottom')

    legend = ax.legend(loc='upper center',
                       bbox_to_anchor=(0.5, -0.15),
                       ncol=4,
                       frameon=True,
                       fontsize=22,
                       title='Class Legend',
                       title_fontsize=22)
    legend.get_frame().set_facecolor('#f8f9fa')
    legend.get_frame().set_edgecolor('#dee2e6')

    ax.set_title('F1 Score Dynamics Across Classification Thresholds',
                 fontsize=32,
                 pad=25,
                 fontweight='bold',
                 color='#2c3e50')
    ax.grid(True, axis='both', alpha=0.25, linestyle='--', color='#95a5a6')
    for spine in ax.spines.values():
        spine.set_color('#2c3e50')
        spine.set_linewidth(1.5)

    return optimal_thresholds


def _plot_auc_panel(ax, probs, targets, class_names):
    """Draw a horizontal bar chart of per-class ROC AUC on ``ax``, sorted ascending.

    Classes whose targets contain a single label value are skipped, because ROC
    AUC is undefined for them.
    """
    auc_data = []
    for i, class_name in enumerate(class_names):
        y_true = targets[:, i]
        if len(np.unique(y_true)) >= 2:
            fpr, tpr, _ = roc_curve(y_true, probs[:, i])
            auc_data.append((class_name, auc(fpr, tpr)))

    auc_data.sort(key=lambda item: item[1])
    sorted_classes = [name for name, _ in auc_data]
    sorted_auc = [value for _, value in auc_data]
    n_bars = len(sorted_auc)

    colors = plt.cm.viridis(np.linspace(0.2, 0.8, n_bars))
    bars = ax.barh(range(n_bars), sorted_auc,
                   height=0.7,
                   color=colors,
                   edgecolor='#2d3436',
                   linewidth=1.2,
                   alpha=0.85)

    # Vertical reference lines for "good" and "excellent" AUC, labelled at the top.
    threshold_good = 0.7
    threshold_excellent = 0.9
    ax.axvline(threshold_good, color='#e74c3c', linestyle='--', linewidth=2.5, alpha=0.8, zorder=0)
    ax.axvline(threshold_excellent, color='#27ae60', linestyle='--', linewidth=2.5, alpha=0.8, zorder=0)
    ax.text(threshold_good + 0.005, n_bars - 0.5, 'Good (0.7)',
            color='#e74c3c', fontsize=20, va='center')
    ax.text(threshold_excellent + 0.005, n_bars - 0.5, 'Excellent (0.9)',
            color='#27ae60', fontsize=20, va='center')

    # Value label to the right of each bar.
    for i, (bar, auc_val) in enumerate(zip(bars, sorted_auc)):
        ax.text(bar.get_width() + 0.015,
                bar.get_y() + bar.get_height() / 2,
                f'{auc_val:.2f}',
                va='center',
                fontsize=18,
                fontweight='bold',
                color=colors[i],
                bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', pad=0.3))

    ax.set_yticks(range(n_bars))
    ax.set_yticklabels(sorted_classes, fontsize=24)
    ax.set_xticks(np.arange(0, 1.1, 0.1))
    ax.tick_params(axis='x', labelsize=24)
    ax.set_xlabel('AUC Score', labelpad=15)
    ax.set_xlim(0.4, 1.05)
    ax.set_title('AUC Distribution with Performance Thresholds',
                 fontsize=32, pad=25, fontweight='bold')
    ax.grid(True, axis='x', alpha=0.4, linestyle=':')


def optimize_thresholds(probs, targets, mlb):
    """Find the best per-class threshold and save a two-panel diagnostic figure.

    Top panel: F1 versus threshold for each class. Bottom panel: per-class ROC AUC
    bars. The figure is written to ``threshold_optimization_analysis_v2.png``.

    Side effect: updates the global seaborn style and ``rcParams`` (font sizes,
    Times New Roman, dpi). These settings persist for later figures.

    Args:
        probs: predicted probabilities, one column per class (indexed ``probs[:, i]``).
        targets: binary targets with the same column order.
        mlb: fitted binarizer; ``mlb.classes_`` gives the column names.

    Returns:
        dict class name -> {'threshold', 'f1_score'}, for classes with positives.
    """
    class_names = mlb.classes_

    # Apply the style before creating any Axes. Settings such as axes.labelsize are
    # read when an Axes is created, so updating them after plt.subplot() (as before)
    # left the top panel's axis labels at the previous size.
    sns.set_style("whitegrid")
    # The seaborn style may reset font.family, so Times New Roman is set again here.
    rcParams.update({
        'font.size': 16,
        'axes.labelsize': 24,
        'axes.titlesize': 28,
        'xtick.labelsize': 18,
        'ytick.labelsize': 18,
        'font.family': 'Times New Roman',
        'figure.dpi': 300,
        'axes.titlepad': 20,
        'axes.labelpad': 12})

    fig = plt.figure(figsize=(18, 16))
    try:
        ax_f1 = fig.add_subplot(2, 1, 1)
        optimal_thresholds = _plot_f1_threshold_panel(ax_f1, probs, targets, class_names)

        ax_auc = fig.add_subplot(2, 1, 2)
        _plot_auc_panel(ax_auc, probs, targets, class_names)

        fig.tight_layout()
        fig.savefig('threshold_optimization_analysis_v2.png',
                    dpi=300,
                    bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting fails part-way.
        plt.close(fig)

    return optimal_thresholds


def plot_evaluate(sorted_classes, mlb=None, pos_samples=None, threshold=None):
    """Plot per-class F1 (left axis) and AUC (right axis) bars and save ``combined_metrics_v3.png``.

    Bars are coloured by each class's positive-sample count, and a colour bar is
    drawn on the right.

    Args:
        sorted_classes: sequence of ``(class_name, f1, auc)`` tuples, plotted in the
            given order (e.g. the last value returned by ``evaluate_species_model``).
        mlb: fitted binarizer whose ``classes_`` gives the column order of
            ``pos_samples``. Defaults to the module-level ``mlb``, so the original
            call ``plot_evaluate(sorted_classes)`` keeps working.
        pos_samples: per-class positive counts in ``mlb.classes_`` order. Defaults
            to the module-level ``pos_samples``; in this notebook that holds the
            training-set counts.
        threshold: decision threshold shown in the title; omitted when ``None``.

    Raises:
        ValueError: if ``mlb`` / ``pos_samples`` are neither passed nor defined globally.
    """
    # Backward compatibility: the original version read these from the globals.
    module_globals = globals()
    if mlb is None:
        mlb = module_globals.get('mlb')
    if pos_samples is None:
        pos_samples = module_globals.get('pos_samples')
    if mlb is None or pos_samples is None:
        raise ValueError("plot_evaluate needs `mlb` and `pos_samples`; pass them explicitly.")
    if not sorted_classes:
        print("plot_evaluate: nothing to plot (sorted_classes is empty).")
        return

    plt.rcParams['figure.dpi'] = 300
    sns.set_style("whitegrid")
    plt.rcParams['font.family'] = 'Times New Roman'

    class_names = [entry[0] for entry in sorted_classes]
    f1_scores = [entry[1] for entry in sorted_classes]
    auc_scores = [entry[2] for entry in sorted_classes]

    # Look up each plotted class's positive count by its position in mlb.classes_.
    class_index = {name: idx for idx, name in enumerate(mlb.classes_.tolist())}
    sorted_pos_samples = np.asarray(pos_samples)[[class_index[name] for name in class_names]]
    # max(..., 1) avoids a 0/0 colour normalisation when no class has positives.
    max_samples = max(np.max(sorted_pos_samples), 1)
    colors = plt.cm.viridis(sorted_pos_samples / max_samples)

    fig, ax1 = plt.subplots(figsize=(18, 10))
    try:
        ax2 = ax1.twinx()

        title = 'Algae Species Classification Performance\nF1 Scores & AUC Metrics'
        if threshold is not None:
            # The title previously contained an unfilled "Threshold=    " placeholder.
            title += f' (Threshold={threshold:.2f} for all algae)'
        ax2.set_title(title, fontsize=24, fontweight='bold', pad=20)

        x = np.arange(len(class_names))
        bar_width = 0.35
        bar_style = dict(width=bar_width, color=colors, edgecolor='#2d3436',
                         linewidth=1.5, alpha=0.85)
        ax1.bar(x - bar_width / 2, f1_scores, label='F1 Score', **bar_style)
        ax2.bar(x + bar_width / 2, auc_scores, hatch='//', label='AUC Score', **bar_style)

        # Reference lines at 0.7 and 0.9.
        for level in (0.7, 0.9):
            ax1.axhline(level, color='gray', linestyle='--', linewidth=1.5, alpha=0.7)

        for ax, ylabel in ((ax1, 'F1 Score'), (ax2, 'AUC Score')):
            ax.set_ylabel(ylabel, fontsize=20, labelpad=15, fontweight='medium')
            ax.set_ylim(0, 1.2)
            ax.set_yticks(np.arange(0, 1.1, 0.1))
            ax.tick_params(axis='y', labelsize=14)

        ax1.set_xticks(x)
        ax1.set_xticklabels(class_names,
                            rotation=45,
                            ha='right',
                            fontsize=20,
                            fontweight='medium')

        # Red value labels above the bars. NaN scores (class absent from the evaluated
        # data) get no label instead of a text placed at y=NaN.
        for i, (f1, auc_val) in enumerate(zip(f1_scores, auc_scores)):
            if np.isfinite(f1):
                ax1.text(i - bar_width / 2, f1 + 0.02, f'{f1:.2f}',
                         ha='center', va='bottom', fontsize=14, color='red', weight='bold')
            if np.isfinite(auc_val):
                ax2.text(i + bar_width / 2, auc_val + 0.02, f'{auc_val:.2f}',
                         ha='center', va='bottom', fontsize=14, color='red', weight='bold')

        # Colour bar in its own axes to the right of the plot.
        sm = plt.cm.ScalarMappable(cmap=plt.cm.viridis,
                                   norm=plt.Normalize(0, max_samples))
        sm.set_array([])
        cax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
        cbar = fig.colorbar(sm, cax=cax)
        cbar.set_label('Positive Samples', fontsize=20, labelpad=15, fontweight='medium')
        cbar.ax.tick_params(labelsize=20)

        # Neutral (white) legend swatches, since the bar colours encode sample counts.
        legend_handles = [
            plt.Rectangle((0, 0), 1, 1, facecolor='white', edgecolor='#2d3436',
                          linewidth=1.5, alpha=0.85, label='F1 Score'),
            plt.Rectangle((0, 0), 1, 1, facecolor='white', edgecolor='#2d3436',
                          linewidth=1.5, alpha=0.85, hatch='//', label='AUC Score'),
        ]
        ax1.legend(handles=legend_handles,
                   loc='upper left',
                   fontsize=14,
                   frameon=True,
                   framealpha=0.9,
                   edgecolor='black')

        fig.subplots_adjust(right=0.85)
        fig.savefig("combined_metrics_v3.png",
                    dpi=300,
                    bbox_inches='tight',
                    facecolor='white')
    finally:
        plt.close(fig)

def plot_combined_roc(probs, targets, mlb, filename="combined_roc_curves.png", max_classes=7):
    """Overlay per-class ROC curves in a single figure and save it to ``filename``.

    Args:
        probs: predicted probabilities, one column per class in ``mlb.classes_`` order.
        targets: binary targets with the same column order.
        mlb: fitted binarizer providing ``classes_``.
        filename: output image path.
        max_classes: plot only the first ``max_classes`` classes. The default of 7
            matches the previous hard-coded limit; ``None`` plots every class.
            Colours and line styles cycle when there are more classes than the
            7-colour palette.

    Returns:
        list of ROC AUC values for the plotted classes, in class order. Classes
        whose targets contain a single label value are skipped (no curve, no entry).
    """
    rcParams['font.family'] = 'Times New Roman'

    # Seven high-contrast colours and four line styles, cycled per class.
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728',
              '#9467bd', '#8c564b', '#e377c2']
    line_styles = ['-', '--', '-.', ':']

    class_names = list(mlb.classes_)
    if max_classes is not None:
        class_names = class_names[:max_classes]

    fig = plt.figure(figsize=(12, 10))
    ax = fig.gca()
    auc_scores = []
    try:
        # Chance diagonal.
        ax.plot([0, 1], [0, 1], 'k--', lw=1.5, alpha=0.6, label='Random guess')

        for i, class_name in enumerate(class_names):
            y_true = targets[:, i]
            if len(np.unique(y_true)) < 2:
                continue  # ROC is undefined with a single label value

            fpr, tpr, _ = roc_curve(y_true, probs[:, i])
            roc_auc = auc(fpr, tpr)
            auc_scores.append(roc_auc)

            ax.plot(fpr, tpr,
                    color=colors[i % len(colors)],
                    linestyle=line_styles[i % len(line_styles)],
                    lw=2.5,
                    alpha=0.9,
                    label=f'{class_name} (AUC={roc_auc:.2f})')

        ax.set_xlim([-0.02, 1.02])
        ax.set_ylim([-0.02, 1.02])
        ax.set_xlabel('False positive rate (FPR)', fontsize=13, fontweight='bold', labelpad=12)
        ax.set_ylabel('True positive rate (TPR)', fontsize=13, fontweight='bold', labelpad=12)
        ax.set_title('Comparison of ROC curves by category', fontsize=16, pad=20,
                     fontweight='bold', color='#2d3436')

        legend = ax.legend(loc='lower right',
                           frameon=True,
                           fontsize=11,
                           ncol=2,
                           columnspacing=0.8,
                           handlelength=2.5,
                           borderpad=0.6,
                           edgecolor='#dfe6e9')
        legend.get_frame().set_facecolor('#ffffff')

        ax.grid(True, alpha=0.3, linestyle=':', color='#636e72')
        for spine in ax.spines.values():
            spine.set_edgecolor('#2d3436')
            spine.set_linewidth(1.2)

        # Only annotate the mean when at least one curve was drawn (avoids "nan").
        if auc_scores:
            ax.text(0.6, 0.18,
                    f'Average AUC: {np.mean(auc_scores):.2f}',
                    fontsize=12,
                    bbox=dict(facecolor='white', alpha=0.8, edgecolor='#b2bec3'))

        fig.tight_layout()
        fig.savefig(filename,
                    dpi=300,
                    bbox_inches='tight',
                    facecolor='white')
    finally:
        plt.close(fig)

    return auc_scores



In [39]:
# Global figure defaults for this cell.
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['figure.dpi'] = 300

# Run configuration
BATCH_SIZE = 4
EPOCHS = 300
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

# Load the dataset
df = pd.read_excel("./datasets.xlsx")


print(f"清洗后数据形状: {df.shape}")

# Preprocess into train/test splits (judging by the returned names: features,
# multi-hot labels, the label binarizer, the feature scaler and feature names).
X_train, X_test, ys_train, ys_test, mlb, x_scaler, feature_names = preprocess_data_class(df)

print(f"训练集样本数: {X_train.shape[0]}")
print(f"测试集样本数: {X_test.shape[0]}")
print(f"特征数量: {X_train.shape[1]}")
print(f"藻类种类数量: {ys_train.shape[1]}")


# ------------------
# 2. Species classification model
# ------------------
print("\n开始训练藻类种类分类模型...")

# Data loaders
train_species_dataset = AlgaeSpeciesDataset(X_train, ys_train)
test_species_dataset = AlgaeSpeciesDataset(X_test, ys_test)
train_species_loader = DataLoader(train_species_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_species_loader = DataLoader(test_species_dataset, batch_size=BATCH_SIZE)

# Per-class positive weights: negatives / positives on the training split
# (the epsilon guards against classes with no positives).
pos_samples = ys_train.sum(axis=0)
neg_samples = len(ys_train) - pos_samples
pos_weights = neg_samples / (pos_samples + 1e-5)

print("\n类别分布和正样本权重:")
for i, class_name in enumerate(mlb.classes_):
    print(f"{class_name}: {pos_samples[i]} 正样本, 权重 = {pos_weights[i]:.2f}")

# Build and train the species classifier
species_model = AlgaeSpeciesClassifier(
    input_size=X_train.shape[1],
    num_species=len(mlb.classes_)
).to(device)

# NOTE(review): the test loader doubles as the validation loader, so early stopping
# (and the threshold search in evaluate_species_model below) sees the test data.
# The reported test metrics are therefore optimistic; holding out a validation split
# from X_train would remove this leakage.
species_model, s_train_losses, s_val_losses, s_val_f1s, s_train_f1s = train_species_model(
    model=species_model,
    train_loader=train_species_loader,
    val_loader=test_species_loader,
    device=device,
    class_weights=pos_weights,
    epochs=EPOCHS,
    patience=10
)

# Save the trained classifier weights
torch.save(species_model.state_dict(), "algae_species_model.pth")

# Evaluate on the test split, searching for the best global threshold
micro_f1, macro_f1, micro_auc, macro_auc, species_probs, species_targets, class_f1, class_auc, sorted_classes = evaluate_species_model(
    model=species_model,
    test_loader=test_species_loader,
    device=device,
    mlb=mlb,
    find_optimal_threshold=True
)
plot_evaluate(sorted_classes)

# Best threshold per class, plus the combined ROC figure
optimal_thresholds = optimize_thresholds(species_probs, species_targets, mlb)
auc_scores = plot_combined_roc(species_probs, species_targets, mlb)

print("\n每个类别的最佳阈值:")
for class_name, info in optimal_thresholds.items():
    print(f"{class_name}: 阈值 = {info['threshold']:.2f}, F1 = {info['f1_score']:.4f}")

# Training history: loss (left) and micro-F1 (right) per epoch.
# seaborn styles may reset the font family, so it is set again afterwards.
sns.set_style("whitegrid")
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['figure.dpi'] = 300
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(s_train_losses, label='Train Loss', color='blue', linestyle='-', marker='o', markersize=4)
plt.plot(s_val_losses, label='Test Loss', color='red', linestyle='--', marker='s', markersize=4)
plt.title('Loss Curves', fontsize=20, fontweight='bold')
plt.xlabel('Epochs', fontsize=16, fontweight='medium')
plt.ylabel('Loss', fontsize=16, fontweight='medium')
plt.legend()
plt.grid(True)

plt.subplot(1, 2, 2)
plt.plot(s_train_f1s, label='Train F1', color='orange', linestyle='--', marker='s', markersize=4)
plt.plot(s_val_f1s, label='Test F1', color='green', linestyle='-', marker='o', markersize=4)
plt.title('Micro F1 Score Curves', fontsize=20, fontweight='bold')
plt.xlabel('Epochs', fontsize=16, fontweight='medium')
plt.ylabel('Micro F1 Score', fontsize=16, fontweight='medium')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.savefig("species_training_history.png")
plt.close()

print("\n模型训练和评估完成")


使用设备: cuda
清洗后数据形状: (114, 14)
原始数据形状: (114, 14)
过滤后藻类种类标签数量: 7
训练集样本数: 102
测试集样本数: 12
特征数量: 12
藻类种类数量: 7

开始训练藻类种类分类模型...

类别分布和正样本权重:
Anabaena: 12 正样本, 权重 = 7.50
Chroomonas: 14 正样本, 权重 = 6.29
Cryptomonas: 41 正样本, 权重 = 1.49
Cyclotella: 21 正样本, 权重 = 3.86
Limnothrix: 24 正样本, 权重 = 3.25
Melosira: 8 正样本, 权重 = 11.75
Pseudanabaena: 44 正样本, 权重 = 1.32
Epoch 10/300 | Train Loss: 0.3864 | Val Loss: 0.4566 | Val F1: 0.5846
Epoch 20/300 | Train Loss: 0.2749 | Val Loss: 0.5270 | Val F1: 0.6429
早停触发于epoch 24
找到最佳阈值: 0.40, F1: 0.6906

藻类种类分类模型评估结果 (阈值 = 0.40000000000000013):
Micro-F1: 0.6923 | Macro-F1: 0.6906
Micro-AUC: 0.8904 | Macro-AUC: 0.9085

各藻类种类的F1分数和AUC:
Limnothrix: F1=0.8571 | AUC=0.9630
Chroomonas: F1=0.7500 | AUC=0.8286
Cryptomonas: F1=0.7273 | AUC=0.9375
Anabaena: F1=0.6667 | AUC=0.9091
Melosira: F1=0.6667 | AUC=1.0000
Pseudanabaena: F1=0.6667 | AUC=0.9062
Cyclotella: F1=0.5000 | AUC=0.8148

每个类别的最佳阈值:
Anabaena: 阈值 = 0.35, F1 = 0.6667
Chroomonas: 阈值 = 0.22, F1 = 0.8889
Cryptomonas: 阈值 = 

In [None]:
import torch
from sklearn.model_selection import StratifiedKFold,KFold
from sklearn.metrics import f1_score
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
def cross_validate_species_model(X, y, mlb, model_class, device,
                                  n_splits=3, batch_size=4, epochs=300, patience=10,
                                  random_state=42):
    """Cross-validate the algae species classifier with multilabel-stratified K-fold.

    For each fold:
    - A fresh ``model_class(input_size=..., num_species=...)`` is trained with
      ``train_species_model``. The held-out fold drives early stopping.
    - The model is evaluated with ``evaluate_species_model``, using the threshold
      that maximises macro-F1 on that same fold. The per-fold scores are
      therefore optimistic.

    Args:
        X: feature matrix, indexed by row (converted with ``np.asarray``).
        y: multi-hot label matrix, one column per class in ``mlb.classes_``.
        mlb: fitted multi-label binarizer.
        model_class: model constructor accepting ``input_size`` and ``num_species``.
        device: torch device.
        n_splits: number of folds.
        batch_size: DataLoader batch size.
        epochs: maximum training epochs per fold.
        patience: early-stopping patience.
        random_state: seed for the fold shuffling (previously hard-coded to 42).

    Returns:
        dict of per-fold lists:
        - training histories: 'fold_train_losses', 'fold_val_losses',
          'fold_val_f1s', 'fold_train_f1s';
        - scalar scores: 'fold_micro_f1', 'fold_macro_f1', 'fold_micro_auc',
          'fold_macro_auc';
        - per-class score dicts: 'fold_class_f1', 'fold_class_auc'.
    """
    from iterstrat.ml_stratifiers import MultilabelStratifiedKFold

    X = np.asarray(X)
    y = np.asarray(y)
    # Multilabel stratification keeps every class's prevalence similar across folds.
    mskf = MultilabelStratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    cv_results = {
        'fold_train_losses': [],
        'fold_val_losses': [],
        'fold_val_f1s': [],
        'fold_micro_f1': [],
        'fold_macro_f1': [],
        'fold_class_f1': [],
        'fold_class_auc': [],
        'fold_train_f1s': [],
        'fold_micro_auc': [],
        'fold_macro_auc': []
    }

    for fold, (train_index, val_index) in enumerate(mskf.split(X, y), 1):
        print(f"\n===== 第 {fold} 折 =====")

        X_train_fold, X_val_fold = X[train_index], X[val_index]
        y_train_fold, y_val_fold = y[train_index], y[val_index]

        # Class weights from the training fold only, so the held-out fold's label
        # distribution does not leak into the loss (previously computed on all of y).
        pos_samples = y_train_fold.sum(axis=0)
        neg_samples = len(y_train_fold) - pos_samples
        pos_weights = neg_samples / (pos_samples + 1e-5)

        train_loader = DataLoader(AlgaeSpeciesDataset(X_train_fold, y_train_fold),
                                  batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(AlgaeSpeciesDataset(X_val_fold, y_val_fold),
                                batch_size=batch_size)

        model = model_class(
            input_size=X_train_fold.shape[1],
            num_species=len(mlb.classes_)
        ).to(device)

        trained_model, train_losses, val_losses, val_f1s, train_f1s = train_species_model(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            device=device,
            class_weights=pos_weights,
            epochs=epochs,
            patience=patience
        )

        micro_f1, macro_f1, micro_auc, macro_auc, probs, targets, class_f1, class_auc, sorted_classes = evaluate_species_model(
            model=trained_model,
            test_loader=val_loader,
            device=device,
            mlb=mlb,
            find_optimal_threshold=True
        )

        cv_results['fold_train_losses'].append(train_losses)
        cv_results['fold_val_losses'].append(val_losses)
        cv_results['fold_val_f1s'].append(val_f1s)
        cv_results['fold_train_f1s'].append(train_f1s)
        cv_results['fold_micro_f1'].append(micro_f1)
        cv_results['fold_macro_f1'].append(macro_f1)
        cv_results['fold_micro_auc'].append(micro_auc)
        cv_results['fold_macro_auc'].append(macro_auc)
        # 'fold_class_f1' was declared but never filled before.
        cv_results['fold_class_f1'].append(class_f1)
        cv_results['fold_class_auc'].append(class_auc)

        print(f"Fold {fold} - Micro F1: {micro_f1:.4f}, Macro F1: {macro_f1:.4f}")

    # Summary: mean ± (population) standard deviation across folds.
    print("\n交叉验证结果总结:")
    print(f"平均 Micro F1: {np.mean(cv_results['fold_micro_f1']):.4f} ± {np.std(cv_results['fold_micro_f1']):.4f}")
    print(f"平均 Macro F1: {np.mean(cv_results['fold_macro_f1']):.4f} ± {np.std(cv_results['fold_macro_f1']):.4f}")
    print(f"平均 Micro AUC: {np.mean(cv_results['fold_micro_auc']):.4f} ± {np.std(cv_results['fold_micro_auc']):.4f}")
    print(f"平均 Macro AUC: {np.mean(cv_results['fold_macro_auc']):.4f} ± {np.std(cv_results['fold_macro_auc']):.4f}")

    return cv_results

# Cross-validate on the training split only; the held-out test split is not used here.
# Positional order: features, labels, label binarizer, model class, device.
cv_results = cross_validate_species_model(
    X_train, ys_train, mlb, AlgaeSpeciesClassifier, device,
    n_splits=5,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    patience=10,
)
import matplotlib.pyplot as plt
import numpy as np

import numpy as np
import matplotlib.pyplot as plt

def _plot_fold_curves(ax, histories, length, color, label):
    """Plot each fold's history faintly and their element-wise mean in bold on ``ax``.

    Every history is truncated to ``length`` epochs so the mean averages aligned epochs.
    """
    truncated = [np.asarray(history[:length], dtype=float) for history in histories]
    for curve in truncated:
        ax.plot(curve, alpha=0.3, color=color)
    ax.plot(np.mean(truncated, axis=0), linewidth=2.5, color=color, label=label)


def plot_cv_results(cv_results):
    """Plot cross-validation results in a 2x2 figure, save it and show it.

    Panels: loss curves, micro-F1 curves, per-fold F1 bars and per-fold AUC bars.
    Curves are truncated to the shortest fold history, because folds stop early at
    different epochs. The figure is saved to ``species_cv_results.png``.

    Side effect: updates the global ``rcParams`` (font sizes, Times New Roman, dpi).

    Args:
        cv_results: dict as returned by ``cross_validate_species_model``.

    Raises:
        ValueError: if ``cv_results`` contains no folds.
    """
    if not cv_results.get('fold_train_losses'):
        raise ValueError("cv_results contains no folds to plot")

    # Apply rc settings *before* creating the figure. The axis-label font size is
    # captured when each Axes is created, so updating afterwards (as before) left
    # every axis label at the previous size.
    rcParams.update({
        'font.size': 14,
        'axes.labelsize': 20,
        'axes.titlesize': 24,
        'xtick.labelsize': 18,
        'ytick.labelsize': 18,
        'font.family': 'Times New Roman',
        'figure.dpi': 300,
        'axes.titlepad': 20,
        'axes.labelpad': 12})

    # Shortest history across folds (each epoch appends to all four histories).
    history_keys = ('fold_train_losses', 'fold_val_losses', 'fold_train_f1s', 'fold_val_f1s')
    min_history_length = min(len(history) for key in history_keys for history in cv_results[key])

    train_loss_color = '#6A5ACD'  # slate blue
    val_loss_color = '#FF6347'    # tomato
    train_f1_color = '#20B2AA'    # light sea green
    val_f1_color = '#FFA500'      # orange
    bar_micro_color = '#4682B4'   # steel blue
    bar_macro_color = '#DC143C'   # crimson
    auc_micro_color = '#32CD32'   # lime green
    auc_macro_color = '#FFD700'   # gold

    fig, axes = plt.subplots(2, 2, figsize=(18, 12))
    try:
        fig.suptitle('Cross-validation Results - Species Classification', fontsize=30, fontweight='bold')

        # 1. Loss curves (top left)
        ax1 = axes[0, 0]
        _plot_fold_curves(ax1, cv_results['fold_train_losses'], min_history_length,
                          train_loss_color, 'Average Training Loss')
        _plot_fold_curves(ax1, cv_results['fold_val_losses'], min_history_length,
                          val_loss_color, 'Average Test Loss')
        ax1.set_title('Loss Curve', fontweight='bold')
        ax1.set_xlabel('Training Epoch', fontweight='medium')
        ax1.set_ylabel('Focal Loss', fontweight='medium')
        ax1.grid(True, alpha=0.3, linestyle='--')
        ax1.legend(fontsize=14)

        # 2. F1 curves (top right)
        ax2 = axes[0, 1]
        _plot_fold_curves(ax2, cv_results['fold_train_f1s'], min_history_length,
                          train_f1_color, 'Average Training F1')
        _plot_fold_curves(ax2, cv_results['fold_val_f1s'], min_history_length,
                          val_f1_color, 'Average Testing F1')
        ax2.set_title('F1 Score Curve', fontweight='bold')
        ax2.set_xlabel('Training Epoch', fontweight='medium')
        # Uses the rc label size like every other axis label (was a one-off fontsize=16).
        ax2.set_ylabel('F1 Score', fontweight='medium')
        ax2.grid(True, alpha=0.3, linestyle='--')
        ax2.legend(fontsize=14)

        n_folds = len(cv_results['fold_micro_f1'])
        fold_labels = [f'Fold {i + 1}' for i in range(n_folds)]
        x = np.arange(n_folds)
        width = 0.35

        # 3. Per-fold F1 scores (bottom left)
        ax3 = axes[1, 0]
        ax3.bar(x - width / 2, cv_results['fold_micro_f1'], width, label='Micro F1', color=bar_micro_color, alpha=0.8)
        ax3.bar(x + width / 2, cv_results['fold_macro_f1'], width, label='Macro F1', color=bar_macro_color, alpha=0.8)
        ax3.set_title('F1 Scores per Fold', fontweight='bold')
        ax3.set_xlabel('Fold', fontweight='medium')
        ax3.set_ylabel('F1 Score', fontweight='medium')
        ax3.set_xticks(x)
        ax3.set_xticklabels(fold_labels)
        ax3.legend(fontsize=14)

        # 4. Per-fold AUC scores (bottom right)
        ax4 = axes[1, 1]
        ax4.bar(x - width / 2, cv_results['fold_micro_auc'], width, label='Micro AUC', color=auc_micro_color, alpha=0.8)
        ax4.bar(x + width / 2, cv_results['fold_macro_auc'], width, label='Macro AUC', color=auc_macro_color, alpha=0.8)
        ax4.set_title('AUC Score per Fold', fontweight='bold')
        ax4.set_xlabel('Fold', fontweight='medium')
        ax4.set_ylabel('AUC Score', fontweight='medium')
        ax4.set_xticks(x)
        ax4.set_xticklabels(fold_labels)
        ax4.legend(fontsize=14)

        # Leave room at the top for the suptitle.
        fig.tight_layout(rect=[0, 0, 1, 0.96])
        fig.savefig("species_cv_results.png", dpi=300, bbox_inches='tight')
        plt.show()
    finally:
        plt.close(fig)

# Render the cross-validation summary with Times New Roman at 300 dpi.
plt.rcParams.update({'font.family': 'Times New Roman', 'figure.dpi': 300})
plot_cv_results(cv_results)

