In [1]:
import os
import cv2
import torch
import numpy as np
from torch.utils.data import Dataset

创建自定义数据集类

In [2]:
from torchvision import transforms

Transform = transforms.Compose([
    transforms.Resize((512, 512)),
    # other transforms, if needed
])

class BreastCancerSegmentationDataset(Dataset):
    """
    乳腺癌分割数据集
    """
    def __init__(self, img_dir, mask_dir, transform=None, one_hot_encode=True, target_size=(256, 256)):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.one_hot_encode = one_hot_encode
        self.target_size = target_size
        self.img_filenames = os.listdir(img_dir)

    def __len__(self):
        return len(self.img_filenames)

    def __getitem__(self, index):
        img_name = self.img_filenames[index]
        img_path = os.path.join(self.img_dir, img_name)
        mask_path = os.path.join(self.mask_dir, img_name[:-4] + '_mask'+img_name[-4:])
        # Skip .ipynb_checkpoints files
        if img_path.endswith(".ipynb_checkpoints") or mask_path.endswith(".ipynb_checkpoints"):
            return self.__getitem__((index + 1) % len(self))
        image = cv2.imread(img_path, cv2.IMREAD_COLOR)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

        if image is None:
            raise FileNotFoundError(f"Image not found at {img_path}")
        if mask is None:
            raise FileNotFoundError(f"Mask not found at {mask_path}")

        # Resize image and mask
        image = cv2.resize(image, self.target_size, interpolation=cv2.INTER_LINEAR)
        mask = cv2.resize(mask, self.target_size, interpolation=cv2.INTER_NEAREST)

        if self.one_hot_encode:
            mask = one_hot_encode(mask, num_classes=3)

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        mask = np.asarray(mask)   # 转换为NumPy数组
        mask = mask.astype(np.float32)   # 变换dtype
        mask = torch.from_numpy(mask) # 转换为Tensor
        return image, mask

class BreastCancerClassificationDataset(Dataset):
    """
    乳腺癌分类数据集
    """
    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        self.img_filenames = os.listdir(img_dir)

    def __len__(self):
        return len(self.img_filenames)

    def __getitem__(self, index):
        img_name = self.img_filenames[index]
        img_path = os.path.join(self.img_dir, img_name)
        image = cv2.imread(img_path, cv2.IMREAD_COLOR)

        if image is None:
            raise FileNotFoundError(f"Image not found at {img_path}")

        if self.transform:
            image = self.transform(image=image)["image"]


        # 从图像名称解析分类标签
        if 'benign' in img_name:
            label = 0
        elif 'malignant' in img_name:
            label = 1
        elif 'normal' in img_name:
            label = 2

        return image, label

def one_hot_encode(mask, num_classes):
    one_hot = np.zeros((mask.shape[0], mask.shape[1], num_classes), dtype=np.uint8)
    for c in range(num_classes):
        one_hot[..., c] = (mask == c)
    return one_hot


使用自定义数据集类加载数据并进行数据增强
划分数据集为训练集和验证集和测试集

In [3]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

train_transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
    A.Rotate(limit=30, p=0.3),
    A.RandomResizedCrop(height=256, width=256, scale=(0.8, 1.0), p=0.2),
    A.Normalize(),
    ToTensorV2()
])

val_transform = A.Compose([
    A.Normalize(),
    ToTensorV2()
])

test_transform = A.Compose([
    A.Normalize(),
    ToTensorV2()
])

dataset = {
    'train': BreastCancerSegmentationDataset("image", "mask", transform=train_transform),
    'val': BreastCancerSegmentationDataset("image", "mask", transform=val_transform),
    'test': BreastCancerSegmentationDataset("image", "mask", transform=test_transform)
}

# 打印一个样本，检查数据格式和变换是否符合预期
dataset['train'][10]


(tensor([[[ 0.3823,  0.3138,  0.2282,  ...,  0.2453,  0.2796,  0.1426],
          [ 0.3309,  0.3138,  0.3994,  ...,  0.2282,  0.2624,  0.6906],
          [ 0.5364,  0.4166,  0.3823,  ...,  0.4679,  0.5536,  0.6563],
          ...,
          [-1.8610, -1.8953, -1.9295,  ..., -1.7754, -1.7583, -1.8268],
          [-1.7412, -1.7583, -1.8782,  ..., -1.8097, -1.8097, -1.8610],
          [-1.9124, -1.9467, -1.9638,  ..., -1.8439, -1.8439, -1.8782]],
 
         [[ 0.5203,  0.4503,  0.3627,  ...,  0.3803,  0.4153,  0.2752],
          [ 0.4678,  0.4503,  0.5378,  ...,  0.3627,  0.3978,  0.8354],
          [ 0.6779,  0.5553,  0.5203,  ...,  0.6078,  0.6954,  0.8004],
          ...,
          [-1.7731, -1.8081, -1.8431,  ..., -1.6856, -1.6681, -1.7381],
          [-1.6506, -1.6681, -1.7906,  ..., -1.7206, -1.7206, -1.7731],
          [-1.8256, -1.8606, -1.8782,  ..., -1.7556, -1.7556, -1.7906]],
 
         [[ 0.7402,  0.6705,  0.5834,  ...,  0.6008,  0.6356,  0.4962],
          [ 0.6879,  0.6705,

定义数据加载器

In [4]:
from torch.utils.data import DataLoader

# 构建训练集、验证集和测试集的 DataLoader
train_loader = DataLoader(dataset['train'], batch_size=16, shuffle=True, num_workers=0, drop_last=True)
val_loader = DataLoader(dataset['val'], batch_size=32, shuffle=False, num_workers=0, drop_last=True)
test_loader = DataLoader(dataset['test'], batch_size=32, shuffle=False, num_workers=0, drop_last=True)


实例化U-Net模型

In [5]:
import torch
import torch.nn as nn
from torchvision.models.segmentation import fcn_resnet50
from sklearn.metrics import f1_score
import time

class UNetTrainer:
    def __init__(self, num_classes=3, lr=1e-4):
        self.model = self.build_model(num_classes)
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def build_model(self, num_classes):
        model = fcn_resnet50(pretrained=False, num_classes=num_classes)
        return model

    def evaluate(self, epoch, dataloader, phase):
        if phase == "train":
            self.model.train()
        else:
            self.model.eval()

        running_loss = 0.0
        running_corrects = 0
        running_f1_score = 0.0

        for images, masks in dataloader:
            images = images.to(self.device)
            masks = masks.to(self.device)

            # 由于数据集中的mask是三维的，需要转换为二维
            masks = torch.mean(masks, dim=3, keepdim=False).long()

            self.optimizer.zero_grad()

            with torch.set_grad_enabled(phase == "train"):
                outputs = self.model(images)['out']
                preds = torch.argmax(outputs, dim=1)
                loss = self.criterion(outputs, masks)

                if phase == "train":
                    loss.backward()
                    self.optimizer.step()

            running_loss += loss.item() * images.size(0)
            running_corrects += torch.sum(preds == masks.data)
            running_f1_score += f1_score(masks.cpu().numpy().ravel(), preds.cpu().numpy().ravel(), average="macro")

        epoch_loss = running_loss / len(dataloader.dataset)
        epoch_acc = running_corrects.double() / len(dataloader.dataset)
        epoch_f1_score = running_f1_score / len(dataloader)

        print(f"{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f} F1: {epoch_f1_score:.4f}")

    def train(self, num_epochs, train_loader, val_loader):
        for epoch in range(num_epochs):
            print(f"Epoch {epoch + 1}/{num_epochs}")
            print("-" * 10)

            start_time = time.time()

            self.evaluate(epoch, train_loader, "train")
            self.evaluate(epoch, val_loader, "val")

            end_time = time.time()
            elapsed_time = end_time - start_time
            print(f"Epoch time: {elapsed_time:.4f}s")

        print("Training complete")

trainer = UNetTrainer()
trainer.train(20, train_loader, val_loader)


Epoch 1/20
----------
train Loss: 0.3580 Acc: 56418.0064 F1: 0.3979
val Loss: 0.0817 Acc: 64445.1319 F1: 1.0000
Epoch time: 126.1678s
Epoch 2/20
----------
train Loss: 0.0506 Acc: 64434.2817 F1: 0.5764
val Loss: 0.0400 Acc: 64434.9001 F1: 0.7152
Epoch time: 126.2510s
Epoch 3/20
----------
train Loss: 0.0351 Acc: 64430.4251 F1: 0.7430
val Loss: 0.0281 Acc: 64445.1319 F1: 1.0000
Epoch time: 128.4155s
Epoch 4/20
----------
train Loss: 0.0229 Acc: 64442.9744 F1: 0.7778
val Loss: 0.0196 Acc: 64445.1319 F1: 1.0000
Epoch time: 127.4309s
Epoch 5/20
----------
train Loss: 0.0168 Acc: 64443.9424 F1: 0.8229
val Loss: 0.0146 Acc: 64445.1319 F1: 1.0000
Epoch time: 127.5493s
Epoch 6/20
----------
train Loss: 0.0132 Acc: 64444.6364 F1: 0.8472
val Loss: 0.0118 Acc: 64445.1319 F1: 1.0000
Epoch time: 128.8598s
Epoch 7/20
----------
train Loss: 0.0107 Acc: 64444.7234 F1: 0.9167
val Loss: 0.0098 Acc: 64445.1319 F1: 1.0000
Epoch time: 127.2442s
Epoch 8/20
----------
train Loss: 0.0089 Acc: 64444.8988 F1: 0