# In [None]:
import os
import cv2
import numpy as np
import shutil
from sklearn.utils import resample
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import datasets, transforms, models
import csv

# ─── Configuration constants ────────────────────────────────────────────────
# Dataset locations: one sub-directory per split under DATA_ROOT
DATA_ROOT = r"/root/"
TRAIN_DIR, TEST_DIR, VAL_DIR = (
    os.path.join(DATA_ROOT, split_name)
    for split_name in ("train", "test", "val")
)

# Where generated images, the merged dataset and the CSV are written
OUTPUT_DIR = r"/root/result"

# Image-processing and training hyperparameters
IMG_SIZE = (224, 224)  # ResNet-50 input size
BATCH_SIZE = 32        # adjust to the available GPU memory
NUM_EPOCHS = 10        # number of training epochs
LEARNING_RATE = 1e-4   # optimizer learning rate
SEED = 114             # random seed

# Make sure the output directory exists (no error if it is already there)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ─── Image augmentation ─────────────────────────────────────────────────────
def augment_image(image):
    """Return a randomly rotated, optionally flipped and shifted copy of *image*.

    The augmentation applies, in order:
      1. a rotation about the image centre by an integer angle drawn from
         ``np.random.randint(-10, 10)`` (half-open, so -10..9 degrees),
      2. a horizontal flip with probability 0.5,
      3. a translation by integer offsets drawn from
         ``np.random.randint(-5, 5)`` (half-open, so -5..4 pixels) per axis.

    Randomness comes from NumPy's global RNG, so seed it beforehand for
    reproducible output.

    Args:
        image: Image array whose first two axes are (rows, cols). Both
            single-channel and multi-channel arrays are accepted.

    Returns:
        The augmented image, with the same width and height as the input.

    Raises:
        ValueError: If *image* is ``None`` (e.g. ``cv2.imread`` failed).
    """
    if image is None:
        raise ValueError("augment_image 收到空图像（cv2.imread 可能读取失败）")

    # Only the first two axes are spatial; slicing keeps colour images working.
    rows, cols = image.shape[:2]

    # Random rotation about the image centre
    angle = np.random.randint(-10, 10)
    rotation = cv2.getRotationMatrix2D((cols / 2, rows / 2), angle, 1)
    rotated = cv2.warpAffine(image, rotation, (cols, rows))

    # Random horizontal flip
    flipped = cv2.flip(rotated, 1) if np.random.rand() > 0.5 else rotated

    # Random translation
    tx = np.random.randint(-5, 5)
    ty = np.random.randint(-5, 5)
    shift = np.float32([[1, 0, tx], [0, 1, ty]])
    translated = cv2.warpAffine(flipped, shift, (cols, rows))

    return translated

# ─── Oversample the NORMAL class ────────────────────────────────────────────
def _list_images(directory):
    """Return the sorted paths of the image files directly inside *directory*.

    Only regular files ending in .jpg/.jpeg/.png (case-insensitive) are kept.
    The listing is sorted so that seeded sampling downstream is reproducible
    regardless of the order the file system returns entries in.

    Raises:
        FileNotFoundError: If *directory* does not exist or holds no image.
    """
    # Raise instead of assert: asserts are stripped when Python runs with -O.
    if not os.path.exists(directory):
        raise FileNotFoundError(f"目录 {directory} 不存在")

    paths = []
    for file in sorted(os.listdir(directory)):
        file_path = os.path.join(directory, file)
        if os.path.isfile(file_path) and file.lower().endswith(('.jpg', '.jpeg', '.png')):
            paths.append(file_path)

    if not paths:
        raise FileNotFoundError(f"在目录 {directory} 中没有找到任何图像文件")
    return paths


normal_dir = os.path.join(TRAIN_DIR, "NORMAL")
normal_images = _list_images(normal_dir)
normal_labels = [0] * len(normal_images)

pneumonia_dir = os.path.join(TRAIN_DIR, "PNEUMONIA")
pneumonia_images = _list_images(pneumonia_dir)
pneumonia_labels = [1] * len(pneumonia_images)

# Number of extra NORMAL images needed to match the PNEUMONIA count.
# A non-positive value simply makes the loop below a no-op.
num_normal_augmented = len(pneumonia_images) - len(normal_images)

augmented_normal_dir = os.path.join(OUTPUT_DIR, "NORMAL_augmented")
os.makedirs(augmented_normal_dir, exist_ok=True)

# Seed NumPy's global RNG so the chosen source images and the augmentation
# parameters drawn by augment_image are the same on every run.
np.random.seed(SEED)

# Each augmented image is written as soon as it is produced instead of being
# collected in a list first, so memory use stays at one image at a time.
augmented_normal_paths = []
for i in range(num_normal_augmented):
    img_path = str(np.random.choice(normal_images))
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise ValueError(f"无法读取图像 {img_path}")

    out_path = os.path.join(augmented_normal_dir, f"augmented_NORMAL_{i}.png")
    if not cv2.imwrite(out_path, augment_image(img)):
        raise OSError(f"无法写入图像 {out_path}")
    augmented_normal_paths.append(out_path)

# Add only the files written by this run. Re-listing the directory would also
# pick up leftovers from earlier runs and push NORMAL past the target count.
normal_images += augmented_normal_paths
normal_labels = [0] * len(normal_images)

# ─── Undersample the PNEUMONIA class ────────────────────────────────────────
# Sample WITHOUT replacement. resample() defaults to replace=True, which draws
# the same path several times; those duplicates collapse into a single file
# when the images are later copied into one directory, leaving PNEUMONIA with
# fewer files than NORMAL. Without replacement n_samples may not exceed the
# population, so it is capped at the number of available PNEUMONIA images.
n_pneumonia_samples = min(len(normal_images), len(pneumonia_images))

sampled_pneumonia_images, sampled_pneumonia_labels = resample(
    pneumonia_images, pneumonia_labels,
    replace=False,
    n_samples=n_pneumonia_samples,
    random_state=SEED
)

pneumonia_images = sampled_pneumonia_images
pneumonia_labels = sampled_pneumonia_labels

# ─── Merge the oversampled and undersampled sets ────────────────────────────
all_images = normal_images + pneumonia_images
all_labels = normal_labels + pneumonia_labels

# Layout: <OUTPUT_DIR>/merged_dataset/{NORMAL,PNEUMONIA}
merged_dataset_dir = os.path.join(OUTPUT_DIR, "merged_dataset")
normal_merged_dir = os.path.join(merged_dataset_dir, "NORMAL")
pneumonia_merged_dir = os.path.join(merged_dataset_dir, "PNEUMONIA")

for target_dir in (merged_dataset_dir, normal_merged_dir, pneumonia_merged_dir):
    os.makedirs(target_dir, exist_ok=True)

# Copy every selected image into the folder of its class
for class_dir, class_images in (
    (normal_merged_dir, normal_images),
    (pneumonia_merged_dir, pneumonia_images),
):
    for img_path in class_images:
        shutil.copy(img_path, class_dir)

# ─── Data transform pipeline ────────────────────────────────────────────────
# The model used below is initialised from ImageNet-pretrained ResNet-50
# weights, so inputs are normalised with the ImageNet channel statistics those
# weights were trained with instead of a generic 0.5 / 0.5 scaling.
data_transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# ─── Dataset loading ────────────────────────────────────────────────────────
# Helper that strips a dataset directory down to its two class folders
def clean_dataset_dir(dataset_dir):
    """Delete everything under *dataset_dir* that is not a class image.

    WARNING: this permanently deletes data in place.

    * Sub-directories other than ``NORMAL`` and ``PNEUMONIA`` are removed
      recursively.
    * Inside ``NORMAL`` / ``PNEUMONIA``, every entry whose name does not end
      in .jpg/.jpeg/.png (case-insensitive) is removed. Nested directories
      such as ``.ipynb_checkpoints`` are removed recursively; calling
      ``os.remove`` on them would raise.
    * Regular files directly inside *dataset_dir* are left untouched.

    Args:
        dataset_dir: Path of the split directory to clean.

    Raises:
        FileNotFoundError: If *dataset_dir* does not exist (from os.listdir).
    """
    class_names = ("NORMAL", "PNEUMONIA")
    image_extensions = ('.jpg', '.jpeg', '.png')

    for subdir in os.listdir(dataset_dir):
        subdir_path = os.path.join(dataset_dir, subdir)
        if not os.path.isdir(subdir_path):
            continue

        # Remove sub-directories that are not one of the two classes
        if subdir not in class_names:
            shutil.rmtree(subdir_path)
            continue

        # Remove invalid entries inside a class directory
        for file in os.listdir(subdir_path):
            if file.lower().endswith(image_extensions):
                continue
            file_path = os.path.join(subdir_path, file)
            # A symlink to a directory must be unlinked, not walked into.
            if os.path.isdir(file_path) and not os.path.islink(file_path):
                shutil.rmtree(file_path)
            else:
                os.remove(file_path)

# Clean the train, test and validation directories
for split_dir in (TRAIN_DIR, TEST_DIR, VAL_DIR):
    clean_dataset_dir(split_dir)

# Training reads the balanced, merged copy; test/val read the original splits
train_dataset = datasets.ImageFolder(merged_dataset_dir, transform=data_transform)
test_dataset = datasets.ImageFolder(TEST_DIR, transform=data_transform)
val_dataset = datasets.ImageFolder(VAL_DIR, transform=data_transform)

# Per-sample weights for WeightedRandomSampler: each sample is weighted by the
# inverse of its class frequency
train_targets = [label for _, label in train_dataset.samples]
class_counts = [
    train_targets.count(class_idx)
    for class_idx in range(len(train_dataset.classes))
]
weights_per_class = [1.0 / cnt for cnt in class_counts]
sample_weights = [weights_per_class[label] for label in train_targets]

sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True,
)

# Settings shared by all three loaders
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=4, pin_memory=True)

train_loader = DataLoader(train_dataset, sampler=sampler, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

# ─── Build ResNet-50 and replace its classification head ───────────────────
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Start from the default pretrained ResNet-50 weights and swap the final
# fully connected layer for a single-logit head
pretrained_weights = models.ResNet50_Weights.DEFAULT
model = models.resnet50(weights=pretrained_weights)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(in_features=num_ftrs, out_features=1)
model = model.to(device)

# ─── Loss function & optimizer ──────────────────────────────────────────────
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=1e-5,
)

# ─── Learning-rate scheduler ────────────────────────────────────────────────
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.1,
    patience=2,
)

# ─── Early stopping state ───────────────────────────────────────────────────
# Training stops once the validation loss has not improved for this many
# consecutive epochs
early_stopping_patience = 3
best_val_loss = float('inf')
epochs_without_improvement = 0

# ─── Training and validation ────────────────────────────────────────────────
# Copy of the weights from the epoch with the lowest validation loss. Without
# it, early stopping would leave the model at its last (non-improving) epoch
# and any later evaluation would use those weights instead of the best ones.
best_model_state = None

for epoch in range(1, NUM_EPOCHS + 1):
    # Training phase
    model.train()
    train_loss = 0.0
    train_correct = 0
    train_total = 0
    for imgs, labels in train_loader:
        # BCEWithLogitsLoss needs float targets shaped like the logits
        imgs, labels = imgs.to(device), labels.float().to(device)
        optimizer.zero_grad()
        logits = model(imgs).squeeze(1)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # Accumulate loss (weighted by batch size) and correct predictions
        train_loss += loss.item() * imgs.size(0)
        preds = torch.sigmoid(logits).round()
        train_correct += (preds == labels).sum().item()
        train_total += imgs.size(0)
    avg_train_loss = train_loss / train_total
    train_accuracy = train_correct / train_total

    # Validation phase
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), labels.float().to(device)
            logits = model(imgs).squeeze(1)
            loss = criterion(logits, labels)
            val_loss += loss.item() * imgs.size(0)
            preds = torch.sigmoid(logits).round()
            val_correct += (preds == labels).sum().item()
            val_total += imgs.size(0)
    avg_val_loss = val_loss / val_total
    val_accuracy = val_correct / val_total

    print(f"Epoch {epoch}, Train Loss: {avg_train_loss:.4f}, Train Acc: {train_accuracy:.4f}, Val Loss: {avg_val_loss:.4f}, Val Acc: {val_accuracy:.4f}")

    # Early stopping
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        epochs_without_improvement = 0
        # clone() is required: state_dict() tensors share storage with the
        # live parameters and would otherwise change as training continues.
        best_model_state = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= early_stopping_patience:
            print(f"Early stopping triggered after {epoch} epochs.")
            break

    # Adjust the learning rate based on the validation loss
    scheduler.step(avg_val_loss)

# Restore the best weights seen during training. The snapshot stays None only
# if the validation loss never compared as an improvement (e.g. it was NaN).
if best_model_state is not None:
    model.load_state_dict(best_model_state)

# ─── Predict on the test set and write submission.csv ──────────────────────
model.eval()
predictions = []
image_paths = []
test_correct = 0
test_total = 0
current_sample_index = 0  # index in test_dataset.samples of the batch's first item

with torch.no_grad():
    for imgs, labels in test_loader:
        batch_size = len(labels)
        targets = labels.float().numpy()

        logits = model(imgs.to(device)).squeeze(1)
        preds = torch.sigmoid(logits).round().cpu().numpy()
        predictions.extend(preds)

        # test_loader does not shuffle, so this batch corresponds to the next
        # batch_size consecutive entries of test_dataset.samples
        batch_end = current_sample_index + batch_size
        image_paths.extend(
            path for path, _ in test_dataset.samples[current_sample_index:batch_end]
        )

        # Running totals for test accuracy
        test_correct += int((preds == targets).sum())
        test_total += batch_size

        current_sample_index = batch_end

test_accuracy = test_correct / test_total

output_path = os.path.join(OUTPUT_DIR, 'submission.csv')
with open(output_path, 'w', newline='') as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow(['Image Path', 'Prediction'])
    writer.writerows(
        [img_path, int(pred)] for img_path, pred in zip(image_paths, predictions)
    )

print(f"submission.csv 文件已生成并保存到 {output_path}")
print(f"Test Accuracy: {test_accuracy:.4f}")

# Recorded cell output: AssertionError: 目录 /root/train\NORMAL 不存在