In [1]:
import torch
from torch import nn
import torchvision
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

预处理

In [2]:
import os


dataset_path = "../../datasets/VOC2012"
TRAIN_PATH = "ImageSets/Segmentation/train.txt"
TEST_PATH = "ImageSets/Segmentation/val.txt"

with open(os.path.join(dataset_path, TRAIN_PATH)) as f:
    seg_image_list = [string for string in f.read().split('\n') if string]

print("Training dataset size:", len(seg_image_list))

Training dataset size: 1464


构造随机裁剪函数，将图片裁成宽 480，高 320 的大小

In [3]:
from typing import Tuple


def random_crop(img: torch.Tensor, label: torch.Tensor, size: Tuple[int, int]=(320, 480)):
    i, j, h, w = torchvision.transforms.RandomCrop.get_params(img, size)
    cropped_img = torchvision.transforms.functional.crop(img, i, j, h, w)
    cropped_label = torchvision.transforms.functional.crop(label, i, j, h, w)

    return cropped_img, cropped_label

In [4]:
VOC_COLORS = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
                [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
                [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
                [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
                [0, 64, 128], [224, 224, 192]]

VOC_CLASSES = ['background', 'aeroplane', 'bicycle', 'bird', 'boat',
               'bottle', 'bus', 'car', 'cat', 'chair', 'cow',
               'diningtable', 'dog', 'horse', 'motorbike', 'person',
               'potted plant', 'sheep', 'sofa', 'train', 'tv/monitor', 'edge']

COLORMAP = dict(zip([tuple(lst) for lst in VOC_COLORS], range(22)))
COLORMAP[(224, 224, 192)] = 0
COLORMAP

{(0, 0, 0): 0,
 (128, 0, 0): 1,
 (0, 128, 0): 2,
 (128, 128, 0): 3,
 (0, 0, 128): 4,
 (128, 0, 128): 5,
 (0, 128, 128): 6,
 (128, 128, 128): 7,
 (64, 0, 0): 8,
 (192, 0, 0): 9,
 (64, 128, 0): 10,
 (192, 128, 0): 11,
 (64, 0, 128): 12,
 (192, 0, 128): 13,
 (64, 128, 128): 14,
 (192, 128, 128): 15,
 (0, 64, 0): 16,
 (128, 64, 0): 17,
 (0, 192, 0): 18,
 (128, 192, 0): 19,
 (0, 64, 128): 20,
 (224, 224, 192): 0}

In [5]:
def label_preprocessing(label: torch.Tensor):
    int_label = torch.zeros(label.shape[1:])
    for x, row in enumerate(label.permute(1, 2, 0)):
        for y, color in enumerate(row):
            int_label[x, y] = COLORMAP[tuple(color.tolist())]
    
    return int_label
    

In [6]:
TRAIN_PATH = "ImageSets/Segmentation/train.txt"
TEST_PATH = "ImageSets/Segmentation/val.txt"
IMAGE_PATH = "JPEGImages"
SEGMENTATION_LABEL_PATH = "SegmentationClass"


class VocDataset(Dataset):
    def __init__(self, dataset_path: str, train: bool=True, shape: Tuple[int, int]=(320, 480)):
        if train:
            with open(os.path.join(dataset_path, TRAIN_PATH)) as f:
                file_list = [string for string in f.read().split('\n') if string]
        else:
            with open(os.path.join(dataset_path, TEST_PATH)) as f:
                file_list = [string for string in f.read().split('\n') if string]

        self.image_list = []
        self.label_list = []
        self.shape = shape

        for file in file_list:
            img = torchvision.io.read_image(
                os.path.join(dataset_path, IMAGE_PATH, file + '.jpg'),
                mode=torchvision.io.image.ImageReadMode.RGB)
            if img.shape[1] < shape[0] or img.shape[2] < shape[1]:
                continue
            self.image_list.append(img)
            label = torchvision.io.read_image(
                os.path.join(dataset_path, SEGMENTATION_LABEL_PATH, file + '.png'),
                mode=torchvision.io.image.ImageReadMode.RGB)
            self.label_list.append(label)
        
        self.size = len(self.image_list)
        print("Read {} images.".format(self.size))

    def __getitem__(self, idx: int):
        cropped_img, cropped_label = random_crop(
            self.image_list[idx], self.label_list[idx], self.shape
        )
        # int_label = cropped_label
        int_label = label_preprocessing(cropped_label)
        return cropped_img, int_label
    
    def __len__(self):
        return self.size


In [7]:
train_dataset = VocDataset(dataset_path, train=True)
test_dataset = VocDataset(dataset_path, train=False)

Read 1114 images.
Read 1078 images.


In [8]:
train_loader = DataLoader(train_dataset, batch_size=3, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1)

for x, y in train_loader:
    break

print(x.shape, y.shape)

torch.Size([3, 3, 320, 480]) torch.Size([3, 320, 480])


准确分类像素数计数

In [9]:
def count_accurate(y_hat: torch.Tensor, y: torch.Tensor):
    return int(torch.sum(y_hat.argmax(dim=1) == y))

构建网络，进行分类

In [10]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, x):
        return self.net(x)


class DownConv(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.conv = DoubleConv(in_channels, out_channels)
        self.down = nn.MaxPool2d(kernel_size=2)
    
    def forward(self, x):
        x1 = self.conv(x)
        x2 = self.down(x1)
        return x1, x2


class UpConv(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)
    
    def forward(self, x, x_copy):
        x_up = self.up(x)
        x_concat = torch.concat([x_up, x_copy], dim=1)
        return self.conv(x_concat)


class UNet(nn.Module):
    def __init__(self, in_channels: int, num_classes: int):
        super().__init__()
        self.down1 = DownConv(3, 64)
        self.down2 = DownConv(64, 128)
        self.down3 = DownConv(128, 256)
        self.down4 = DownConv(256, 512)
        self.bottleneck = DoubleConv(512, 1024)
        self.up1 = UpConv(1024, 512)
        self.up2 = UpConv(512, 256)
        self.up3 = UpConv(256, 128)
        self.up4 = UpConv(128, 64)
        self.out = nn.Conv2d(64, num_classes, kernel_size=3, padding=1)
    
    def forward(self, x):
        x1, x1_down = self.down1(x)
        x2, x2_down = self.down2(x1_down)
        x3, x3_down = self.down3(x2_down)
        x4, x4_down = self.down4(x3_down)
        y0 = self.bottleneck(x4_down)
        y1 = self.up1(y0, x4)
        y2 = self.up2(y1, x3)
        y3 = self.up3(y2, x2)
        y4 = self.up4(y3, x1)
        return self.out(y4)

使用 GPU，准备训练

In [11]:
gpu = torch.device("cuda")
gpu

device(type='cuda')

In [12]:
unet = UNet(3, 21).to(gpu)

In [13]:
lr = 0.01
num_epochs = 20
optimizer = torch.optim.Adam(unet.parameters(), lr)

In [14]:
# cross_entropy = nn.CrossEntropyLoss()
# def loss(prediction: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
#     l = cross_entropy(prediction.permute(0, 2, 3, 1).flatten(0, 2), label.flatten(0, 2))
#     return l

In [15]:
import json

with open("./num-classes.json") as f:
    num_classes = json.load(f)
total_num = sum(num_classes)
weights = [total_num / num_class for num_class in num_classes]


cross_entropy = nn.CrossEntropyLoss(weight=torch.Tensor(weights).to(gpu))
def loss(prediction: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
    l = cross_entropy(prediction.permute(0, 2, 3, 1).flatten(0, 2), label.flatten(0, 2))
    return l

weights = [weight / sum(weights) for weight in weights]
weights

[0.00059640177741287,
 0.06576704490683068,
 0.15442672573680793,
 0.052459913629721334,
 0.07733380319113269,
 0.07718465950793414,
 0.02676270592391312,
 0.03350841070995501,
 0.017342202841490108,
 0.04092966103497042,
 0.0568208376433905,
 0.03462928101585405,
 0.026951623808923176,
 0.05127708762700316,
 0.04053930025233477,
 0.009761997318590397,
 0.07010757379946872,
 0.051942961503561855,
 0.03241862152709714,
 0.029391689156171033,
 0.04984749708743679]

In [16]:
last_epoch = 0

# checkpoint = torch.load("./parameter_epoch_1_lr_0.01_weighted_loss_2.tar")
# unet.load_state_dict(checkpoint['model'])
# optimizer.load_state_dict(checkpoint['optimizer'])
# last_epoch = checkpoint['epoch']

In [17]:
# num_accurate = 0
# num_total = 0

# with torch.no_grad():
#     for idx, (x, y) in enumerate(test_loader):
#         x: torch.Tensor
#         y: torch.Tensor
#         x = x.float().to(gpu)
#         y = y.long().to(gpu)
#         num_total += int(y.shape[1] * y.shape[2])
#         y_hat = unet(x)
#         num_accurate += count_accurate(y_hat, y)

#         if (idx + 1) % 10 == 0:
#             print(idx + 1)

In [18]:
for epoch in range(last_epoch, num_epochs):
    unet.zero_grad()
    for idx, (x, y) in enumerate(train_loader):
        x: torch.Tensor
        y: torch.Tensor
        x = x.float().to(gpu)
        y = y.long().to(gpu)
        y_hat = unet(x)
        l = loss(y_hat, y)
        l.backward()
        optimizer.step()

        if (idx + 1) % 20 == 0:
            print(f"{idx + 1} batches passed")
    
    # Count accuracy
    num_accurate = 0
    num_total = 0
    
    with torch.no_grad():
        for idx, (x, y) in enumerate(test_loader):
            x: torch.Tensor
            y: torch.Tensor
            x = x.float().to(gpu)
            y = y.long().to(gpu)
            num_total += int(y.shape[1] * y.shape[2])
            y_hat = unet(x)
            num_accurate += count_accurate(y_hat, y)
            if (idx + 1) % 100 == 0:
                print(f"{idx + 1} test images passed.")
    
    checkpoint = {
        'model': unet.state_dict(),
        'optimizer': optimizer.state_dict(),
        'accuracy': num_accurate / num_total,
        'epoch': epoch + 1
    }
    torch.save(checkpoint, f"./parameter_epoch_{epoch + 1}_lr_{lr}_weighted_loss_bn.tar")
    
    print("Epoch: {}, Loss: {}, Validation accuracy: {}".format(
          epoch + 1, float(l), num_accurate / num_total)) 


20 batches passed
40 batches passed
60 batches passed
80 batches passed
100 batches passed
120 batches passed
140 batches passed
160 batches passed
180 batches passed
200 batches passed
220 batches passed
240 batches passed
260 batches passed
280 batches passed
300 batches passed
320 batches passed
340 batches passed
360 batches passed
100 test images passed.
200 test images passed.
300 test images passed.
400 test images passed.
500 test images passed.
600 test images passed.
700 test images passed.
800 test images passed.
900 test images passed.
1000 test images passed.
Epoch: 1, Loss: 2.58341646194458, Validation accuracy: 0.014875438456632653
20 batches passed
40 batches passed
60 batches passed
80 batches passed
100 batches passed


KeyboardInterrupt: 