# 使用 data_package 做图像分类示例（COVID-CT）

本 Notebook 演示如何使用 `data_package` 中的 `ImageClassificationDataset`
和 `transforms` 来完成：

1. 构建图片路径和标签；
2. 划分 train / val / test；
3. 定义预处理和数据增强；
4. 构建 DataLoader，并简单跑一个训练循环。

In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from pathlib import Path

from datasets import ImageClassificationDataset
from transforms import (
    Compose, Resize, RandomHorizontalFlip, RandomVerticalFlip,RandomRotate90,
    ToTensor, Normalize,
)
from utils import train_val_test_split,compute_channel_mean_std

In [2]:
root = Path("../data/COVID-CT")
classes = ["COVID", "non-COVID"]  # label 0, 1

image_paths = []
labels = []

for label, cls in enumerate(classes):
    cls_dir = root / cls
    for p in cls_dir.glob("*.png"):
        image_paths.append(str(p))
        labels.append(label)

print("总图片数:", len(image_paths))
print("前 5 张:", image_paths[:5])
print("前 5 个标签:", labels[:5])

总图片数: 2481
前 5 张: ['../data/COVID-CT/COVID/Covid (1035).png', '../data/COVID-CT/COVID/Covid (1143).png', '../data/COVID-CT/COVID/Covid (513).png', '../data/COVID-CT/COVID/Covid (1054).png', '../data/COVID-CT/COVID/Covid (864).png']
前 5 个标签: [0, 0, 0, 0, 0]


In [3]:
(
    train_images, val_images, test_images,
    train_labels, val_labels, test_labels
) = train_val_test_split(
    image_paths, labels,
    train_ratio=0.7, val_ratio=0.15, test_ratio=0.15,
    shuffle=True, seed=42,
)

print("train / val / test:", len(train_images), len(val_images), len(test_images))

train / val / test: 1736 372 373


In [4]:
# 1. 建一个“没有 Normalize 的”临时 Dataset
tmp_tf = Compose([
    Resize((224, 224)),
    ToTensor(mask_mode="none"),  # 重要：还不能 Normalize
])

tmp_train_ds = ImageClassificationDataset(
    image_paths=train_images,
    labels=train_labels,
    transform=tmp_tf,
)

# 2. 估计 mean / std（可以只用前几百个 batch）
mean, std = compute_channel_mean_std(tmp_train_ds, batch_size=16, max_batches=50)
print("mean:", mean)
print("std :", std)

mean: tensor([0.6400, 0.6400, 0.6400])
std : tensor([0.2713, 0.2713, 0.2713])


In [5]:
# 训练集：带数据增强
train_transform = Compose([
    Resize((224, 224)),
    RandomHorizontalFlip(p=0.5),
    RandomVerticalFlip(p=0.5),
    ToTensor(),  
    Normalize(
        mean= mean.tolist(),
        std=std.tolist(),
    ),
])

# 验证 / 测试集：只做确定性变换
eval_transform = Compose([
    Resize((224, 224)),
    ToTensor(),
    Normalize(
        mean=mean.tolist(),
        std=std.tolist(),
    ),
])


In [6]:
train_ds = ImageClassificationDataset(
    image_paths=train_images,
    labels=train_labels,
    transform=train_transform,
)

val_ds = ImageClassificationDataset(
    image_paths=val_images,
    labels=val_labels,
    transform=eval_transform,
)

test_ds = ImageClassificationDataset(
    image_paths=test_images,
    labels=test_labels,
    transform=eval_transform,
)

batch_size = 16

train_loader = DataLoader(train_ds, batch_size, shuffle=True)
val_loader = DataLoader(val_ds, batch_size, shuffle=False)
test_loader = DataLoader(test_ds, batch_size, shuffle=False)

for batch in train_loader:
    print("batch keys:", batch.keys())
    print("image shape:", batch["image"].shape)  # [B, 3, 224, 224]
    print("label shape:", batch["label"].shape)  # [B]
    meta = batch["meta"]  # dict: {"image_path": [...], "index": ...}

    # 取第 0 个样本对应的 meta
    meta_example = {k: v[0] for k, v in meta.items()}
    print("meta example:", meta_example)
    break


batch keys: dict_keys(['image', 'label', 'meta'])
image shape: torch.Size([16, 3, 224, 224])
label shape: torch.Size([16])
meta example: {'path': '../data/COVID-CT/COVID/Covid (37).png', 'index': tensor(1074)}


In [7]:
class SimpleCNN(nn.Module):
    def __init__(self, num_classes: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # 112x112

            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # 56x56

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),  # [B, 64, 1, 1]
        )
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        x = self.net(x)
        x = x.view(x.size(0), -1)  # [B, 64]
        x = self.fc(x)
        return x


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleCNN(num_classes=len(classes)).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


In [8]:
def train_one_epoch(model, loader, optimizer, criterion, device):
    model.train()
    total_loss = 0.0
    total_correct = 0
    total = 0

    for batch in loader:
        imgs = batch["image"].to(device)
        labels = batch["label"].to(device)

        optimizer.zero_grad()
        logits = model(imgs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * imgs.size(0)
        preds = logits.argmax(dim=1)
        total_correct += (preds == labels).sum().item()
        total += imgs.size(0)

    return total_loss / total, total_correct / total


def eval_one_epoch(model, loader, criterion, device):
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total = 0

    with torch.no_grad():
        for batch in loader:
            imgs = batch["image"].to(device)
            labels = batch["label"].to(device)

            logits = model(imgs)
            loss = criterion(logits, labels)

            total_loss += loss.item() * imgs.size(0)
            preds = logits.argmax(dim=1)
            total_correct += (preds == labels).sum().item()
            total += imgs.size(0)

    return total_loss / total, total_correct / total

In [9]:
for epoch in range(3):  # 跑几轮看看流程
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc = eval_one_epoch(model, val_loader, criterion, device)
    print(f"Epoch {epoch+1}: "
          f"train_loss={train_loss:.4f}, train_acc={train_acc:.4f}, "
          f"val_loss={val_loss:.4f}, val_acc={val_acc:.4f}")

Epoch 1: train_loss=0.6782, train_acc=0.5639, val_loss=0.6543, val_acc=0.6478
Epoch 2: train_loss=0.6386, train_acc=0.6262, val_loss=0.6509, val_acc=0.5968
Epoch 3: train_loss=0.6039, train_acc=0.6642, val_loss=0.5793, val_acc=0.6425
