# 使用 PyTorch 以 ADE20K 训练集训练 DeepLabV3-ResNet101

## 1. 基础准备

### 1.1 安装相关库（首次运行或在新的虚拟环境中运行时执行）


In [None]:
# Quote version specifiers: %pip passes the line to the shell, where an unquoted `<`
# is parsed as input redirection instead of a version constraint.
%pip install "numpy<2.0"
# CUDA 12.1 builds of PyTorch; pick the index URL matching your driver (see `nvidia-smi`).
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
%pip install opencv-python
%pip install matplotlib
%pip install tqdm

### 1.2 检查 GPU 驱动与 CUDA 版本


In [None]:
# Run the NVIDIA `nvidia-smi` tool to check the available GPU and its CUDA version.
!nvidia-smi

## 2. 搭建模型

### 2.1 导入相关库


In [7]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import cv2
from PIL import Image
import torchvision.transforms as transforms
import torchvision.models.segmentation as models
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt

### 2.2 定义数据集类

数据集可从 [ADE20K](http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip) 下载。解压后，目录应位于 `dataset/ade20k/ADEChallengeData2016`（即下方代码中的 `DATASETPATH`）。


In [8]:
DATASETPATH = Path("dataset/ade20k/ADEChallengeData2016")  # dataset root directory


class ADE20kDataset(Dataset):
    """ADE20K (ADEChallengeData2016) semantic-segmentation dataset.

    Images are read from ``<root>/images/<split>`` and label masks from
    ``<root>/annotations/<split>``. Image files are sorted by name, and each
    image is paired with the mask file that has the same stem
    (e.g. ``x.jpg`` <-> ``x.png``).

    Args:
        transforms: Optional callable applied to the RGB image (given as a PIL image).
        mask_transforms: Optional callable applied to the mask (given as a PIL image).
            Its output is expected to be a tensor with a leading size-1 channel
            dimension, which is squeezed away.
        root: Dataset root directory; defaults to ``DATASETPATH``.
        split: Sub-directory to read, e.g. ``"training"`` (default) or ``"validation"``.

    Raises:
        FileNotFoundError: If an image has no mask with the same stem.
    """

    def __init__(self, transforms=None, mask_transforms=None, root=None, split="training"):
        self.transforms = transforms
        self.mask_transforms = mask_transforms
        self.root = Path(DATASETPATH if root is None else root)
        self.split = split
        self.image_dir = self.root / "images" / split
        self.mask_dir = self.root / "annotations" / split

        self.images = sorted(os.listdir(self.image_dir))
        # Pair by file stem rather than by sorted position, so one stray or missing
        # file cannot silently shift every image onto the wrong mask.
        mask_by_stem = {Path(name).stem: name for name in os.listdir(self.mask_dir)}
        missing = [name for name in self.images if Path(name).stem not in mask_by_stem]
        if missing:
            raise FileNotFoundError(
                f"No mask found in {self.mask_dir} for {len(missing)} image(s), e.g. {missing[0]}"
            )
        self.masks = [mask_by_stem[Path(name).stem] for name in self.images]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` pair at ``idx``, after the optional transforms."""
        img_path = self.image_dir / self.images[idx]
        mask_path = self.mask_dir / self.masks[idx]

        # cv2.imread returns None instead of raising on unreadable files.
        image = cv2.imread(str(img_path))
        if image is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR

        mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

        if self.transforms:
            image = self.transforms(Image.fromarray(image))
        if self.mask_transforms:
            # Build the PIL mask from the *mask* array. The previous code used `image`,
            # which by this point may already be a tensor.
            mask = self.mask_transforms(Image.fromarray(mask))
            mask = mask.squeeze(0)  # drop the channel dim -> (H, W)
        return image, mask

### 2.3 数据预处理和加载


In [9]:
# Image pipeline: resize, convert to a float tensor in [0, 1], then normalise with the
# standard ImageNet mean/std.
data_transforms = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Mask pipeline: masks hold integer class ids, so
#   * resize with NEAREST, because bilinear interpolation blends neighbouring ids into
#     labels that do not exist;
#   * use PILToTensor rather than ToTensor, because ToTensor divides 8-bit values by 255.
#     Ids then become fractions below 1 that `.long()` truncates to 0 for every pixel.
mask_transforms = transforms.Compose([
    transforms.Resize((512, 512), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.PILToTensor(),  # uint8 tensor of shape (1, H, W), values unchanged
])

train_dataset = ADE20kDataset(transforms=data_transforms, mask_transforms=mask_transforms)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=0)

#### 2.3.1 检查数据


In [10]:
# Display image and label.
# Sanity check: pull one batch from the loader and print the tensor shapes.
batch_iterator = iter(train_loader)
train_features, train_labels = next(batch_iterator)
for label, tensor in (("Feature", train_features), ("Labels", train_labels)):
    print(f"{label} batch shape: {tensor.size()}")

### 2.4 模型预加载

此处使用 PyTorch（torchvision）中的 DeepLabV3 模型（ResNet-101 骨干网络）。注意：torchvision 并未提供 DeepLabV3+。


In [None]:
# torchvision DeepLabV3 with a ResNet-101 backbone and 151 output classes.
# NOTE(review): 151 presumably means the 150 ADE20K classes plus id 0 ("other"); confirm.
# `weights=None` is the current spelling of the deprecated `pretrained=False`: no
# pretrained segmentation weights are loaded. The backbone's initialisation follows
# torchvision's `weights_backbone` default; check the docs if that matters.
model = models.deeplabv3_resnet101(weights=None, num_classes=151)

# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

### 2.5 训练模型


In [None]:
import torch.optim as optim
import torch.nn as nn

# Per-pixel multi-class loss over the model's 151 output channels.
# Assumes mask pixels hold class ids in [0, 150]; TODO confirm against the dataset.
criterion = nn.CrossEntropyLoss()
# The previous lr=0.05 is far too large for Adam on a ResNet-101 and typically makes
# training diverge. 1e-4 is a conventional starting point.
optimizer = optim.Adam(model.parameters(), lr=1e-4)

num_epochs = 20

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    num_batches = 0

    for images, masks in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
        images = images.to(device)
        # CrossEntropyLoss needs int64 class-index targets of shape (N, H, W).
        masks = masks.to(device).long()

        optimizer.zero_grad()
        outputs = model(images)['out']  # per-pixel class logits
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    # max(..., 1) avoids a ZeroDivisionError if the loader yields no batches.
    avg_loss = running_loss / max(num_batches, 1)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

print('Training complete')

### 2.6 保存模型


In [None]:
# Persist only the learned parameters (the state_dict), not the whole model object.
trained_state = model.state_dict()
torch.save(trained_state, 'deeplabv3plus_ade20k.pth')

### 2.7 测试模型


In [None]:
# Reload the saved parameters.
# - map_location lets a GPU-trained checkpoint load on a CPU-only machine.
# - weights_only=True restricts unpickling to tensors and plain containers.
state_dict = torch.load('deeplabv3plus_ade20k.pth', map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()

# Load a test image (replace the placeholder path with a real file).
test_image_path = 'path_to_test_image'
with Image.open(test_image_path) as img:  # close the file handle after decoding
    test_image = img.convert("RGB")
# Same preprocessing as training (resize, ToTensor, normalise), plus a batch dimension.
test_image = data_transforms(test_image).unsqueeze(0).to(device)

with torch.no_grad():
    output = model(test_image)['out'][0]
output_predictions = output.argmax(0)  # per-pixel class id at 512x512 resolution

# Show the class map. Fixed vmin/vmax keeps each class id on the same colour across
# images, instead of rescaling to whatever ids appear in this one.
plt.imshow(
    output_predictions.cpu().numpy(),
    cmap='nipy_spectral',
    vmin=0,
    vmax=output.shape[0] - 1,
)
plt.axis('off')
plt.show()