<a href="https://colab.research.google.com/github/Hijuli66/33/blob/master/MobileNetV3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from sklearn.metrics import accuracy_score, f1_score
from PIL import Image
from google.colab import drive

# Mount Google Drive unless the mount point already exists (so re-running the cell is harmless).
if os.path.exists('/content/drive'):
    print("Google Drive 已挂载")
else:
    drive.mount('/content/drive', force_remount=True)

# Project layout on Drive: split CSVs under Data/, model checkpoints under Images/Models/.
drive_path = '/content/drive/MyDrive/33/'
data_path, model_path = (os.path.join(drive_path, sub) for sub in ('Data/', 'Images/Models/'))
os.makedirs(model_path, exist_ok=True)

# 1. Load the train / val / test split CSVs (log, then re-raise, on failure).
try:
    train_df = pd.read_csv(os.path.join(data_path, 'train.csv'))
    val_df = pd.read_csv(os.path.join(data_path, 'val.csv'))
    test_df = pd.read_csv(os.path.join(data_path, 'test.csv'))
    print("成功加载 train.csv, val.csv 和 test.csv")
    print(f"训练集大小：{len(train_df)}，验证集大小：{len(val_df)}，测试集大小：{len(test_df)}")
except FileNotFoundError as e:
    print(f"错误：未找到数据集文件，{e}")
    raise
except Exception as e:
    print(f"读取数据集时出错：{e}")
    raise

# Check required columns in EVERY split: all three are wrapped in ChestXRayDataset,
# which reads 'image_path' and 'label' per row, so a bad val/test file would
# otherwise only fail deep inside training or evaluation.
required_columns = ['image_id', 'image_path', 'label']
for split_name, split_df in (('train.csv', train_df), ('val.csv', val_df), ('test.csv', test_df)):
    missing_cols = [col for col in required_columns if col not in split_df.columns]
    if missing_cols:
        print(f"错误：{split_name} 缺少以下必要列：{missing_cols}")
        raise ValueError("数据集缺少必要列")

# 2. Data augmentation / preprocessing.
# Training pipeline: resize to the 224x224 MobileNetV3 input, then random rotation
# (up to 15 degrees), horizontal flip, brightness/contrast jitter and a 3x3 Gaussian
# blur. The blur is not wrapped in RandomApply, so it runs on every training sample.
train_transforms = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.RandomRotation(degrees=15),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.GaussianBlur(kernel_size=3),
        transforms.ToTensor(),
        # Single-channel normalisation (images are loaded as grayscale 'L').
        transforms.Normalize(mean=[0.485], std=[0.229]),
    ]
)

# Validation / test pipeline: deterministic resize + the same normalisation, no augmentation.
val_test_transforms = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485], std=[0.229]),
    ]
)

# Custom dataset: one DataFrame row -> (grayscale image, label).
class ChestXRayDataset(Dataset):
    """Chest X-ray dataset backed by a DataFrame with 'image_path' and 'label' columns.

    'image_path' is joined onto the global ``drive_path``. Each image is converted
    to a single-channel ('L') PIL image before ``transform`` (if any) is applied.
    """

    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``; log and re-raise if the image can't be read."""
        row = self.dataframe.iloc[idx]
        img_path = os.path.join(drive_path, row['image_path'])
        try:
            # Context manager closes the file handle immediately instead of leaving it
            # open until garbage collection; convert() returns an independent copy.
            with Image.open(img_path) as img:
                image = img.convert('L')  # grayscale
        except Exception as e:
            print(f"无法加载图像 {img_path}：{e}")
            raise
        label = row['label']
        if self.transform:
            image = self.transform(image)
        return image, label

# Build the data loaders (validation/test unshuffled so predictions keep CSV row order).
train_dataset = ChestXRayDataset(train_df, transform=train_transforms)
val_dataset = ChestXRayDataset(val_df, transform=val_test_transforms)
test_dataset = ChestXRayDataset(test_df, transform=val_test_transforms)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

# 3. Load an ImageNet-pretrained MobileNetV3-Small and adapt it to 1-channel input / 2 classes.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# `pretrained=True` is deprecated in torchvision >= 0.13; IMAGENET1K_V1 is the
# checkpoint it selected.
model = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1)

# Replace the RGB stem conv with a single-channel one. A freshly initialised conv
# would stay random for the entire run, because the training phases below never
# unfreeze features.0. Summing the pretrained kernels over the RGB axis gives exactly
# the response the pretrained conv would have to a grayscale image copied into all
# three channels, so the pretrained low-level filters are kept.
rgb_stem = model.features[0][0]
gray_stem = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1, bias=False)
with torch.no_grad():
    gray_stem.weight.copy_(rgb_stem.weight.sum(dim=1, keepdim=True))
model.features[0][0] = gray_stem  # single-channel input

# Swap the final classifier layer for a 2-way (binary) head.
num_features = model.classifier[3].in_features
model.classifier[3] = nn.Linear(num_features, 2)
model = model.to(device)

# Evaluation helper: returns (loss, accuracy, F1, predictions).
def evaluate_model(model, data_loader, criterion, dataset_name="验证集"):
    """Score ``model`` on ``data_loader`` without gradients and print a one-line summary.

    Batches are moved to the module-level ``device``. The model is switched to eval mode.

    Args:
        model: classifier whose output row-wise argmax is the predicted class.
        data_loader: yields ``(inputs, labels)`` batches.
        criterion: loss applied to ``(outputs, labels)``; assumed to return a batch mean.
        dataset_name: prefix for the printed summary line.

    Returns:
        ``(loss, acc, f1, preds)``. ``loss`` is the per-sample average of ``criterion``
        (batch loss weighted by batch size). ``acc`` and ``f1`` come from sklearn
        (F1 with ``average='weighted'``). ``preds`` is a list of predicted class
        indices in loader order.
    """
    model.eval()
    total_loss = 0.0
    predicted, targets = [], []
    with torch.no_grad():
        for batch_x, batch_y in data_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            logits = model(batch_x)
            total_loss += criterion(logits, batch_y).item() * batch_x.size(0)
            predicted.extend(logits.max(dim=1).indices.cpu().numpy())
            targets.extend(batch_y.cpu().numpy())
    mean_loss = total_loss / len(data_loader.dataset)
    acc = accuracy_score(targets, predicted)
    f1 = f1_score(targets, predicted, average='weighted')
    print(f"{dataset_name} 损失：{mean_loss:.4f}，准确率：{acc:.4f}，F1 分数：{f1:.4f}")
    return mean_loss, acc, f1, predicted

# Training loop
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, phase, model_save_path,
                scheduler=None):
    """Train ``model`` for ``num_epochs`` epochs, validating and checkpointing after each.

    After every epoch the model is scored on ``val_loader`` with ``evaluate_model``.
    Whenever the weighted validation F1 improves, ``model.state_dict()`` is written
    to ``model_save_path``.

    Args:
        model, train_loader, val_loader, criterion, optimizer: the usual training objects.
            Batches are moved to the module-level ``device``.
        num_epochs: number of passes over ``train_loader``.
        phase: label used only in log lines.
        model_save_path: where the best-F1 ``state_dict`` is saved.
        scheduler: LR scheduler stepped once per epoch with the validation loss.
            If None (default), a new ``ReduceLROnPlateau(optimizer, mode='min',
            factor=0.1, patience=2)`` is created for ``optimizer``. This keeps the
            scheduler tied to the optimizer actually being stepped. A module-level
            scheduler built for an earlier phase's optimizer would adjust that stale
            optimizer and never lower this phase's learning rate.

    Returns:
        The best validation F1 seen during this call.
    """
    if scheduler is None:
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2)

    # Start below any achievable F1 so the first epoch always writes a checkpoint.
    # With 0.0, a run whose F1 is 0 every epoch would never save, and later loads
    # of model_save_path would fail.
    best_val_f1 = float('-inf')
    for epoch in range(num_epochs):
        # Training pass
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * inputs.size(0)
        epoch_loss = running_loss / len(train_loader.dataset)
        print(f"阶段 {phase} Epoch {epoch+1}/{num_epochs} 训练损失：{epoch_loss:.4f}")

        # Validation pass (after training on this epoch)
        val_loss, val_acc, val_f1, _ = evaluate_model(model, val_loader, criterion, dataset_name=f"阶段 {phase} Epoch {epoch+1}/{num_epochs} 验证集")

        # Keep the best model by validation F1
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            torch.save(model.state_dict(), model_save_path)

        scheduler.step(val_loss)
    return best_val_f1

# 4. Baseline before any training: score the adapted, not-yet-fine-tuned model on the validation set.
criterion = nn.CrossEntropyLoss(reduction='mean')
print("训练前初始验证：评估预训练模型在验证集上的性能")
evaluate_model(model, val_loader, criterion=criterion, dataset_name="初始验证集")

# 5. Phase 1: train only the classifier head; everything in model.features stays frozen.
print("开始阶段 1：完整训练（只训练顶层）")
model.requires_grad_(False)            # freeze every parameter ...
model.classifier.requires_grad_(True)  # ... then re-enable the classifier head
optimizer = optim.Adam(model.classifier.parameters(), lr=0.001, weight_decay=0.0001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2)
phase1_ckpt = os.path.join(model_path, 'mobilenetv3_full_phase1.pth')
train_model(model, train_loader, val_loader, criterion, optimizer,
            num_epochs=10, phase=1, model_save_path=phase1_ckpt)

# 6. Phase 2: additionally unfreeze model.features[9], [10] and [11], then fine-tune them
# together with the head at a 10x lower learning rate.
print("开始阶段 2：完整训练（解冻后几层）")
# Index-based selection; same parameter set as matching names "features.9/10/11".
for stage in model.features[9:12]:
    stage.requires_grad_(True)
# NOTE(review): model.features[12] (the last stage before pooling) stays frozen here;
# confirm that is intended, since it sits between the unfrozen stages and the head.
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.0001, weight_decay=0.0001)
# New plateau scheduler for the NEW optimizer. The phase-1 scheduler is bound to the
# phase-1 optimizer, so stepping it would never lower this phase's learning rate.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2)
train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10, phase=2, model_save_path=os.path.join(model_path, 'mobilenetv3_full_phase2.pth'))

# Save the final model. train_model checkpoints the epoch with the best validation F1,
# so restore that phase-2 checkpoint instead of keeping the last epoch's weights.
best_phase2_path = os.path.join(model_path, 'mobilenetv3_full_phase2.pth')
if os.path.exists(best_phase2_path):
    model.load_state_dict(torch.load(best_phase2_path, map_location=device))
final_model_path = os.path.join(model_path, 'mobilenetv3_final.pth')
torch.save(model.state_dict(), final_model_path)
# os.path.join avoids the "Models//mobilenetv3_final.pth" double slash (model_path ends with '/').
print(f"最终模型已保存至：{final_model_path}")

# 7. Test-set evaluation using the saved final weights.
print("开始测试集评估")
model.load_state_dict(
    torch.load(os.path.join(model_path, 'mobilenetv3_final.pth'))
)
test_loss, test_acc, test_f1, test_preds = evaluate_model(
    model, test_loader, criterion, dataset_name="测试集"
)

# test_loader is unshuffled, so predictions line up with test_df rows.
test_df['prediction'] = test_preds
predictions_csv = os.path.join(data_path, 'test_predictions.csv')
test_df.to_csv(predictions_csv, index=False)
print(f"测试集预测结果已保存至：{data_path}/test_predictions.csv")

Google Drive 已挂载
成功加载 train.csv, val.csv 和 test.csv
训练集大小：7681，验证集大小：961，测试集大小：961
训练前初始验证：评估预训练模型在验证集上的性能




初始验证集 损失：0.6818，准确率：0.5744，F1 分数：0.5736
开始阶段 1：完整训练（只训练顶层）
阶段 1 Epoch 1/10 验证集 损失：0.5708，准确率：0.6785，F1 分数：0.6532
阶段 1 Epoch 2/10 验证集 损失：0.3345，准确率：0.8491，F1 分数：0.8461
阶段 1 Epoch 3/10 验证集 损失：0.2745，准确率：0.8866，F1 分数：0.8855
阶段 1 Epoch 4/10 验证集 损失：0.2487，准确率：0.9074，F1 分数：0.9069
阶段 1 Epoch 5/10 验证集 损失：0.2840，准确率：0.8762，F1 分数：0.8740
阶段 1 Epoch 6/10 验证集 损失：0.2528，准确率：0.8866，F1 分数：0.8864
阶段 1 Epoch 7/10 验证集 损失：0.2246，准确率：0.9105，F1 分数：0.9099
阶段 1 Epoch 8/10 验证集 损失：0.3257，准确率：0.8491，F1 分数：0.8473
阶段 1 Epoch 9/10 验证集 损失：0.2377，准确率：0.8980，F1 分数：0.8970
阶段 1 Epoch 10/10 验证集 损失：0.2604，准确率：0.8855，F1 分数：0.8853
开始阶段 2：完整训练（解冻后几层）
阶段 2 Epoch 1/10 验证集 损失：0.1690，准确率：0.9334，F1 分数：0.9333
阶段 2 Epoch 2/10 验证集 损失：0.1461，准确率：0.9428，F1 分数：0.9427
阶段 2 Epoch 3/10 验证集 损失：0.1370，准确率：0.9469，F1 分数：0.9469
阶段 2 Epoch 4/10 验证集 损失：0.1360，准确率：0.9459，F1 分数：0.9459
阶段 2 Epoch 5/10 验证集 损失：0.1251，准确率：0.9480，F1 分数：0.9480
阶段 2 Epoch 6/10 验证集 损失：0.1214，准确率：0.9501，F1 分数：0.9500
阶段 2 Epoch 7/10 验证集 损失：0.1197，准确率：0.9542，F1 分数：0.9542
阶段 

In [None]:
from google.colab import drive
import os
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from PIL import Image
import numpy as np

# Mount Google Drive only if the mount point does not exist yet.
if os.path.exists('/content/drive'):
    print("Google Drive 已挂载")
else:
    drive.mount('/content/drive', force_remount=True)

# Paths: split CSVs live in Data/; weights and the exported feature table in Images/Models/.
drive_path = '/content/drive/MyDrive/33/'
data_path, model_path = (os.path.join(drive_path, sub) for sub in ('Data/', 'Images/Models/'))
feature_save_path = os.path.join(model_path, 'features.csv')

# Custom dataset: each item also carries its image id and resolved path.
class ChestXRayDataset(Dataset):
    """Chest X-ray dataset over a DataFrame with 'image_id', 'image_path' and 'label' columns.

    'image_path' is joined onto the global ``drive_path``. Images are converted to
    single-channel ('L') PIL images before ``transform`` (if any) is applied.
    """

    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return ``(image, label, image_id, img_path)``; log and re-raise if the image can't be read."""
        row = self.dataframe.iloc[idx]
        img_path = os.path.join(drive_path, row['image_path'])
        try:
            # Close the file handle right away; convert() returns an independent copy.
            with Image.open(img_path) as img:
                image = img.convert('L')  # grayscale
        except Exception as e:
            print(f"无法加载图像 {img_path}：{e}")
            raise
        label = row['label']
        image_id = row['image_id']
        if self.transform:
            image = self.transform(image)
        return image, label, image_id, img_path

# Evaluation-time preprocessing; kept identical to the training notebook's
# validation/test pipeline (resize to 224x224, to tensor, single-channel normalisation).
test_transforms = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485], std=[0.229]),
    ]
)

# Load the test split; log and re-raise so the cell stops on failure.
try:
    test_df = pd.read_csv(os.path.join(data_path, 'test.csv'))
except FileNotFoundError as e:
    print(f"错误：未找到 test.csv，{e}")
    raise
except Exception as e:
    print(f"读取 test.csv 时出错：{e}")
    raise
else:
    print(f"成功加载 test.csv，测试集大小：{len(test_df)}")

# ChestXRayDataset reads image_id, image_path and label from every row.
required_columns = ['image_id', 'image_path', 'label']
missing_cols = [col for col in required_columns if col not in test_df.columns]
if missing_cols:
    print(f"错误：test.csv 缺少以下必要列：{missing_cols}")
    raise ValueError("数据集缺少必要列")

# Unshuffled loader, so extracted features come out in test.csv row order.
test_dataset = ChestXRayDataset(test_df, transform=test_transforms)
test_loader = DataLoader(dataset=test_dataset, batch_size=16, shuffle=False)

# Rebuild the training architecture (1-channel stem, 2-class head) and load the trained weights.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.mobilenet_v3_small(weights=None)
model.features[0][0] = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1, bias=False)  # single-channel input
num_features = model.classifier[3].in_features
model.classifier[3] = nn.Linear(num_features, 2)  # binary classification
try:
    # map_location lets a checkpoint saved during a GPU session be restored on a
    # CPU-only runtime (otherwise torch.load tries to recreate CUDA tensors).
    model.load_state_dict(torch.load(os.path.join(model_path, 'mobilenetv3_final.pth'), map_location=device))
except FileNotFoundError as e:
    print(f"错误：未找到 mobilenetv3_final.pth，{e}")
    raise
model = model.to(device)
model.eval()

# Release cached GPU memory held by PyTorch's allocator.
torch.cuda.empty_cache()

# Feature extraction
def extract_features(model, data_loader, device=None):
    """Return globally average-pooled ``model.features`` embeddings for every sample.

    Args:
        model: network exposing a ``features`` sub-module. For the MobileNetV3-Small
            used here at 224x224 the feature map is [batch, 576, 7, 7].
        data_loader: iterable of ``(inputs, labels, image_ids, paths)`` batches, as
            yielded by a DataLoader over the ChestXRayDataset defined in this cell.
        device: where to run the model. Defaults to the device of the model's first
            parameter, i.e. wherever ``model`` was moved to.

    Returns:
        ``(image_ids, features, labels)`` as lists in loader order. ``features``
        holds one 1-D numpy array per image. ``image_ids`` holds plain Python
        values (not tensors).
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()  # inference mode in case the model contains BatchNorm/Dropout layers
    features_list = []
    labels_list = []
    image_ids_list = []
    with torch.no_grad():
        for inputs, labels, image_ids, _ in data_loader:
            feature_map = model.features(inputs.to(device))
            # Global average pooling -> [batch, channels, 1, 1] -> [batch, channels]
            pooled = torch.nn.functional.adaptive_avg_pool2d(feature_map, (1, 1)).flatten(1)
            features_list.extend(pooled.cpu().numpy())
            labels_list.extend(labels.numpy())
            # The default collate turns numeric ids into a tensor. Convert to plain
            # Python values so the CSV stores "123" and not "tensor(123)".
            # String ids already arrive as a list.
            if isinstance(image_ids, torch.Tensor):
                image_ids = image_ids.tolist()
            image_ids_list.extend(image_ids)
    return image_ids_list, features_list, labels_list

# Run feature extraction over the whole test set.
image_ids, features, labels = extract_features(model, test_loader)

# One CSV row per image. The embedding is stored as a single comma-separated string column.
feature_df = pd.DataFrame(
    {
        'image_id': image_ids,
        'feature_vector': list(map(lambda vec: ','.join(map(str, vec)), features)),
        'label': labels,
    }
)
feature_df.to_csv(feature_save_path, index=False)
print(f"特征向量已保存至：{feature_save_path}")

Google Drive 已挂载
成功加载 test.csv，测试集大小：961
特征向量已保存至：/content/drive/MyDrive/33/Images/Models/features.csv


下面两个单元格分别是：肺部分割可视化（原图 / 肺部掩码 / 绿色叠加）代码，以及基于 Grad-CAM 的肺部热力图生成代码

In [None]:
!pip install -q torchxrayvision

from google.colab import drive
import os
import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import torchxrayvision as xrv

drive.mount('/content/drive')

# ====================== Input / output paths ======================
drive_path = '/content/drive/MyDrive/33/'
control_dir = os.path.join(drive_path, 'Images', 'QaTa-dataset', 'control_images')
covid_dir = os.path.join(drive_path, 'Images', 'QaTa-dataset', 'QaTa-COV19')
result_dir = os.path.join(drive_path, 'Images', 'Models', '')  # trailing '' keeps the final '/'
os.makedirs(result_dir, exist_ok=True)

# ====================== Lung-segmentation model ======================
print("正在加载 TorchXRayVision PSPNet 模型（肺部分割）...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Module.to() and Module.eval() both return the module itself, so they can be chained.
model = xrv.baseline_models.chestx_det.PSPNet().to(device).eval()
print("模型加载完成！")

# ====================== Helper: centre-crop to a square ======================
def center_crop_to_square(img_np):
    """Return the largest centred square crop of ``img_np``.

    Only the first two axes (height, width) are cropped; any trailing axes such as
    colour channels are kept. When the excess is odd, the extra row/column is
    dropped from the bottom/right. Returns a view, not a copy.
    """
    rows, cols = img_np.shape[:2]
    side = rows if rows < cols else cols
    top = (rows - side) // 2
    left = (cols - side) // 2
    return img_np[top:top + side, left:left + side]

# ====================== Per-folder processing ======================
def process_folder(image_folder, label_name, num_images=10):
    """Save an original / lung-mask / green-overlay triptych for up to ``num_images`` images.

    Args:
        image_folder: directory containing .png/.jpg/.jpeg images (extension check is case-insensitive).
        label_name: group tag used in the progress bar, figure title and output filename.
        num_images: maximum number of images to process.

    Files are taken in sorted name order so repeated runs pick the same images
    (``os.listdir`` order is arbitrary). Each image is centre-cropped to a square and
    segmented with the global PSPNet ``model`` at 512x512. The figure is written to
    ``result_dir`` as ``<stem>_<label_name>_comparison.png``.
    """
    all_images = sorted(
        f for f in os.listdir(image_folder) if f.lower().endswith(('.png', '.jpg', '.jpeg'))
    )
    selected = all_images[:num_images]
    # Report the number actually processed (the folder may hold fewer than num_images).
    print(f"\n正在处理：{label_name}（共 {len(selected)} 张）")

    for filename in tqdm(selected, desc=label_name):
        img_path = os.path.join(image_folder, filename)
        # Close the file handle right away; np.array copies the converted pixels.
        with Image.open(img_path) as src:
            img_np = np.array(src.convert('L'))

        # Step 1: centre-crop to a square so later resizes do not distort the aspect ratio.
        img_square = center_crop_to_square(img_np)

        # Step 2: torchxrayvision intensity normalisation -> tensor (1,1,H,W) -> 512x512.
        img_norm = xrv.datasets.normalize(img_square, 255)
        img_tensor = torch.from_numpy(img_norm).float().unsqueeze(0).unsqueeze(0).to(device)
        img_tensor = torch.nn.functional.interpolate(img_tensor, size=(512, 512), mode='bilinear', align_corners=False)

        # Step 3: segmentation. Channels 4 and 5 are the left/right lungs; merge them with max.
        # NOTE(review): the ChestX-Det targets overlap (lungs vs. heart, clavicles, hila), so each
        # output channel is treated as an independent binary mask (per-channel sigmoid). A softmax
        # across all channels would split probability between overlapping structures and punch
        # holes in the lung mask. Confirm against the torchxrayvision PSPNet documentation.
        with torch.no_grad():
            logits = model(img_tensor)
        lung_prob = torch.sigmoid(logits[0, [4, 5]]).max(dim=0).values.cpu().numpy()
        mask_512 = (lung_prob > 0.5).astype(np.uint8) * 255

        # Step 4: bring image and mask to a common 224x224 output size.
        img_224 = cv2.resize(img_square, (224, 224))
        mask_224 = cv2.resize(mask_512, (224, 224), interpolation=cv2.INTER_NEAREST)
        mask_224 = cv2.dilate(mask_224, np.ones((3, 3), np.uint8), iterations=1)

        # Step 5: paint the mask green and blend 40/60 with the grayscale image.
        gray_bgr = cv2.cvtColor(img_224, cv2.COLOR_GRAY2BGR)
        overlay = gray_bgr.copy()
        overlay[mask_224 > 127] = [0, 255, 0]
        blended = cv2.addWeighted(overlay, 0.4, gray_bgr, 0.6, 0)

        # Step 6: save the three-panel figure; always close it, even if saving fails.
        fig, axs = plt.subplots(1, 3, figsize=(15, 5))
        try:
            axs[0].imshow(img_224, cmap='gray'); axs[0].set_title('原图 Original'); axs[0].axis('off')
            axs[1].imshow(mask_224, cmap='gray'); axs[1].set_title('肺部掩码 Lung Mask'); axs[1].axis('off')
            axs[2].imshow(cv2.cvtColor(blended, cv2.COLOR_BGR2RGB)); axs[2].set_title('绿色叠加 Overlay'); axs[2].axis('off')
            fig.suptitle(f"{filename} - {label_name.upper()}", fontsize=16)
            fig.tight_layout()

            save_name = os.path.splitext(filename)[0] + f"_{label_name}_comparison.png"
            fig.savefig(os.path.join(result_dir, save_name), dpi=150, bbox_inches='tight')
        finally:
            plt.close(fig)

# ====================== Run both groups ======================
for folder, tag in ((control_dir, "control"), (covid_dir, "covid")):
    process_folder(folder, tag, num_images=10)

print("\n全部完成！20 张高清对比图已保存到：")
print(result_dir)

In [None]:
!pip install -q torchxrayvision tqdm

import os
import torch
import torch.nn.functional as F
import torchxrayvision as xrv
import torchvision.models as models
from torchvision import transforms
import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
from google.colab import drive

# ====================== Mount & paths ======================
drive.mount('/content/drive')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
drive_path = '/content/drive/MyDrive/33/'

# Source images (QaTa dataset): one folder per group.
data_root = os.path.join(drive_path, 'Images/QaTa-dataset/')
control_dir = os.path.join(data_root, 'control_images')
covid_dir = os.path.join(data_root, 'QaTa-COV19')

# Trained classifier checkpoint.
model_path = os.path.join(drive_path, 'Images/Models/mobilenetv3_final.pth')

# Output: one sub-folder per group under Images/Heatmaps.
base_heatmap_dir = os.path.join(drive_path, 'Images/Heatmaps')
control_heatmap_dir = os.path.join(base_heatmap_dir, 'control')
covid_heatmap_dir = os.path.join(base_heatmap_dir, 'covid')
for out_dir in (control_heatmap_dir, covid_heatmap_dir):
    os.makedirs(out_dir, exist_ok=True)

print("\n".join([
    "热力图将保存到：",
    f"  正常组 → {control_heatmap_dir}",
    f"  肺炎组 → {covid_heatmap_dir}",
]))

# ====================== 1. Classifier ======================
print("正在加载 MobileNetV3 肺炎分类模型...")
model = models.mobilenet_v3_small(weights=None)
# Same surgery as at training time: 1-channel stem conv and a 2-way final linear layer.
model.features[0][0] = torch.nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1, bias=False)
num_ftrs = model.classifier[3].in_features
model.classifier[3] = torch.nn.Linear(num_ftrs, 2)
model.load_state_dict(torch.load(model_path, map_location=device))
model = model.to(device).eval()
# Grad-CAM is taken on the output of the last stage of model.features.
target_layer = model.features[-1]

# ====================== 2. Lung-segmentation model ======================
print("正在加载 PSPNet 肺部分割模型...")
seg_model = xrv.baseline_models.chestx_det.PSPNet().to(device).eval()

# ====================== 3. Classifier preprocessing (matches training-time eval pipeline) ======================
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485], std=[0.229]),
    ]
)

# ====================== Grad-CAM & mask helpers ======================
def get_gradcam_heatmap(img_tensor, target_class, net=None, layer=None, size=(224, 224)):
    """Compute a Grad-CAM map for ``target_class`` on a single-image batch.

    Args:
        img_tensor: input batch; only sample 0 is explained.
        target_class: index of the output logit to explain.
        net: classifier; defaults to the global ``model``.
        layer: module whose output activations/gradients are used; defaults to
            the global ``target_layer``.
        size: (height, width) of the returned map.

    Returns:
        float32 numpy array of shape ``size`` with values in [0, 1]. It is all
        zeros when the gradient-weighted activation is nowhere positive.

    The forward hook is removed in a ``finally`` block, so an exception during the
    forward/backward pass cannot leave stale hooks accumulating on ``layer``.
    """
    net = model if net is None else net
    layer = target_layer if layer is None else layer
    captured = {}

    def _capture(module, inputs, output):
        # Keep the activations and ask autograd for the gradient w.r.t. this output
        # (equivalent to grad_output[0] of a full backward hook on ``layer``).
        captured['act'] = output.detach()
        output.register_hook(lambda grad: captured.__setitem__('grad', grad.detach()))

    net.zero_grad()
    handle = layer.register_forward_hook(_capture)
    try:
        logits = net(img_tensor)
        logits[0, target_class].backward()
    finally:
        handle.remove()

    acts, grads = captured['act'], captured['grad']
    # Channel weights = spatially averaged gradients; CAM = ReLU of the weighted channel sum.
    weights = grads.mean(dim=(2, 3), keepdim=True)
    cam = F.relu((weights * acts).sum(dim=1, keepdim=True))  # (N, 1, h, w)
    # Bilinear upsampling with half-pixel centres (align_corners=False), close to cv2.resize's default.
    cam = F.interpolate(cam, size=size, mode='bilinear', align_corners=False)
    cam = cam[0, 0].cpu().numpy()
    peak = cam.max()
    if peak > 0:
        cam = cam / peak
    return cam

def get_lung_mask(img_pil):
    """Return a 224x224 uint8 {0, 1} lung mask for ``img_pil`` using the global ``seg_model``.

    The image is converted to grayscale and normalised with
    ``xrv.datasets.normalize(..., 255)``. It is resized to 512x512 and segmented.
    Channels 4 and 5 (left/right lung) are merged with a max, and a pixel is lung
    when that probability exceeds 0.5.
    """
    img_np = np.array(img_pil.convert('L'))
    img_norm = xrv.datasets.normalize(img_np, 255)
    img_tensor = torch.from_numpy(img_norm).float().unsqueeze(0).unsqueeze(0).to(device)
    img_tensor = F.interpolate(img_tensor, size=(512, 512), mode='bilinear', align_corners=False)
    with torch.no_grad():
        logits = seg_model(img_tensor)
    # NOTE(review): ChestX-Det targets overlap (lungs vs. heart, clavicles, hila), so each
    # channel is an independent binary mask -> per-channel sigmoid. A softmax across the
    # channels would split probability between overlapping structures and leave holes in
    # the lung mask. Confirm against the torchxrayvision PSPNet documentation.
    lung_prob = torch.sigmoid(logits[0, [4, 5]]).max(dim=0).values
    mask = (lung_prob > 0.5).float().cpu().numpy()
    mask = cv2.resize(mask, (224, 224), interpolation=cv2.INTER_NEAREST)
    return (mask > 0.5).astype(np.uint8)

# ====================== Main loop with resume support ======================
def process_folder_with_resume(folder_path, save_dir, label_name, class_idx):
    """Write a 4-panel Grad-CAM figure for every image in ``folder_path`` not yet processed.

    Source image ``xxx.png`` counts as done when ``save_dir`` already contains
    ``{label_name}_xxx.png``, so an interrupted run can simply be restarted.

    Args:
        folder_path: directory with the source .png/.jpg/.jpeg images.
        save_dir: output directory for the figures.
        label_name: group tag; used as output filename prefix and in log lines.
        class_idx: classifier output index whose Grad-CAM is rendered. Callers pass
            0 for control and 1 for covid, presumably matching the training label
            encoding (confirm).

    Returns:
        Number of figures written by this call (0 when everything was already done).
    """
    # Sorted so processing (and any interruption point) is reproducible.
    all_files = sorted(f for f in os.listdir(folder_path) if f.lower().endswith(('.png', '.jpg', '.jpeg')))

    # Original filenames that already have an output: "control_xxx.png" -> "xxx.png".
    prefix = label_name + "_"
    existing_files = set()
    if os.path.exists(save_dir):
        existing_files = {f[len(prefix):] for f in os.listdir(save_dir) if f.startswith(prefix)}

    todo_files = [f for f in all_files if f not in existing_files]

    total = len(all_files)
    remain = len(todo_files)
    # Count only outputs that match a current source image. A stray or renamed file in
    # save_dir would otherwise inflate this (done could exceed total).
    done = total - remain

    print(f"\n开始处理 {label_name} 组")
    print(f"  总图像：{total} 张  |  已完成：{done} 张  |  剩余：{remain} 张")

    if remain == 0:
        print(f"  {label_name} 组已全部完成，无需处理！")
        return 0

    written = 0
    for filename in tqdm(todo_files, desc=f"{label_name} 处理中"):
        img_path = os.path.join(folder_path, filename)
        try:
            # Close the file handle right away; convert() returns an independent copy.
            with Image.open(img_path) as src:
                img_pil = src.convert('L')
        except Exception as e:
            print(f"\n跳过损坏图像：{filename} ({e})")
            continue

        img_tensor = transform(img_pil).unsqueeze(0).to(device)
        cam = get_gradcam_heatmap(img_tensor, class_idx)
        lung_mask = get_lung_mask(img_pil)
        cam_masked = cam * lung_mask  # keep attention inside the lungs only

        img_224 = cv2.resize(np.array(img_pil), (224, 224))
        heatmap_color = cv2.applyColorMap(np.uint8(255 * cam_masked), cv2.COLORMAP_JET)
        superimposed = cv2.addWeighted(cv2.cvtColor(img_224, cv2.COLOR_GRAY2BGR), 0.6, heatmap_color, 0.4, 0)

        # Four-panel figure; always closed, even if drawing or saving raises.
        fig, axs = plt.subplots(1, 4, figsize=(20, 5))
        try:
            titles = ['原图', '肺部掩码', 'Grad-CAM（仅肺区）', '最终热力图']
            imgs = [img_224, lung_mask * 255, cam_masked, cv2.cvtColor(superimposed, cv2.COLOR_BGR2RGB)]
            cmaps = ['gray', 'gray', 'jet', None]

            for ax, img, title, cmap in zip(axs, imgs, titles, cmaps):
                if cmap == 'jet':
                    im = ax.imshow(img, cmap='jet', vmin=0, vmax=1)
                    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
                else:
                    ax.imshow(img, cmap=cmap)
                ax.set_title(title)
                ax.axis('off')

            fig.suptitle(f"{filename} - {label_name.upper()}", fontsize=16)
            fig.tight_layout()

            save_name = f"{label_name}_{filename}"
            fig.savefig(os.path.join(save_dir, save_name), dpi=200, bbox_inches='tight')
        finally:
            plt.close(fig)
        written += 1

    # Report images actually written; skipped (corrupt) images are not counted.
    print(f"{label_name} 组全部处理完成！")
    print(f"  本次新增：{written} 张  |  累计完成：{done + written}/{total} 张")
    return written

# ====================== Run (resumable: finished images are skipped) ======================
for src_dir, out_dir, group, cls in (
    (control_dir, control_heatmap_dir, "control", 0),
    (covid_dir, covid_heatmap_dir, "covid", 1),
):
    process_folder_with_resume(src_dir, out_dir, group, class_idx=cls)

print("\n所有热力图生成任务已完成或已存在！")
print(f"查看路径：{base_heatmap_dir}")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
热力图将保存到：
  正常组 → /content/drive/MyDrive/33/Images/Heatmaps/control
  肺炎组 → /content/drive/MyDrive/33/Images/Heatmaps/covid
正在加载 MobileNetV3 肺炎分类模型...
正在加载 PSPNet 肺部分割模型...

开始处理 control 组
  总图像：5000 张  |  已完成：5000 张  |  剩余：0 张
  control 组已全部完成，无需处理！

开始处理 covid 组
  总图像：4603 张  |  已完成：3325 张  |  剩余：1278 张


covid 处理中: 100%|██████████| 1278/1278 [4:01:05<00:00, 11.32s/it]

covid 组全部处理完成！
  本次新增：1278 张  |  累计完成：4603/4603 张

所有热力图生成任务已完成或已存在！
查看路径：/content/drive/MyDrive/33/Images/Heatmaps



