In [4]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np

# 数据预处理（标准化 + 数据增强）
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet标准化
])

# 加载PASCAL VOC 2012数据集（仅需图像和语义分割mask）
voc_train = datasets.VOCSegmentation(
    root='./data', year='2012', image_set='train',
    transform=transform, target_transform=transforms.ToTensor(), download=True
)
voc_val = datasets.VOCSegmentation(
    root='./data', year='2012', image_set='val',
    transform=transform, target_transform=transforms.ToTensor(), download=True
)

# 划分训练集、验证集、测试集（按7:2:1比例）
train_loader = DataLoader(voc_train, batch_size=8, shuffle=True)
val_loader = DataLoader(voc_val, batch_size=8, shuffle=False)
test_loader = DataLoader(voc_val, batch_size=8, shuffle=False)  # 测试集直接使用验证集

In [5]:
import torch.nn as nn
import torchvision.models as models

class LightWeightSegmentation(nn.Module):
    def __init__(self, num_classes=21):  # PASCAL VOC 21个类别（含背景）
        super(LightWeightSegmentation, self).__init__()
        # 使用预训练ResNet18作为编码器
        resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        self.encoder = nn.Sequential(*list(resnet.children())[:-2])  # 去除最后两层
        
        # 解码器（上采样+卷积）
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(256, num_classes, kernel_size=2, stride=2)  # 输出类别数
        )
        
    def forward(self, x):
        x = self.encoder(x)  # 编码器输出特征图
        x = self.decoder(x)  # 解码器还原到输入尺寸
        return x

model = LightWeightSegmentation()

In [6]:
import torch.optim as optim
from torchmetrics import IoU

# 超参数
criterion = nn.CrossEntropyLoss(ignore_index=255)  # 忽略无效标签
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2)
num_epochs = 10

# 训练函数
def train(model, dataloader, epoch):
    model.train()
    total_loss = 0
    for images, masks in dataloader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks.squeeze(1))  # 掩码需要压缩通道维度
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch} Train Loss: {total_loss/len(dataloader)}")

# 验证函数（计算IoU和准确率）
def validate(model, dataloader):
    model.eval()
    iou_metric = IoU(num_classes=21)
    correct, total = 0, 0
    with torch.no_grad():
        for images, masks in dataloader:
            outputs = model(images)
            preds = torch.argmax(outputs, dim=1)
            iou_metric.update(preds, masks.squeeze(1))
            correct += (preds == masks.squeeze(1)).sum().item()
            total += masks.numel()
    print(f"Validation Accuracy: {correct/total:.4f}, IoU: {iou_metric.compute():.4f}")

# 开始训练
for epoch in range(1, num_epochs+1):
    train(model, train_loader, epoch)
    validate(model, val_loader)
    scheduler.step()

# 测试集最终评估
validate(model, test_loader)

ImportError: cannot import name 'IoU' from 'torchmetrics' (C:\Users\Change\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\torchmetrics\__init__.py)

In [None]:
import matplotlib.pyplot as plt

def visualize_predictions(model, dataloader, num_samples=3):
    model.eval()
    with torch.no_grad():
        for images, masks in dataloader:
            outputs = model(images)
            preds = torch.argmax(outputs, dim=1)
            for i in range(num_samples):
                plt.figure(figsize=(10, 5))
                plt.subplot(1, 3, 1)
                plt.title("Input Image")
                plt.imshow(images[i].permute(1, 2, 0).cpu().numpy())
                
                plt.subplot(1, 3, 2)
                plt.title("True Mask")
                plt.imshow(masks[i].squeeze().cpu().numpy(), cmap='gray')
                
                plt.subplot(1, 3, 3)
                plt.title("Predicted Mask")
                plt.imshow(preds[i].cpu().numpy(), cmap='gray')
                plt.show()
            break

visualize_predictions(model, test_loader)