In [10]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as tfs
import matplotlib.pyplot as plt
import cv2
import glob
import os
from PIL import Image
import random
# Record library versions next to the outputs so results can be reproduced later.
print(*(lib.__version__ for lib in (torch, cv2)), sep='\n')

1.3.0
4.1.2


Data preprocessing reference (数据预处理参考): https://zhuanlan.zhihu.com/p/32506912

In [30]:
def read_images(root=r'../input/pascal-voc-2012/VOC2012', train=True):
    """List the image and label paths of one VOC segmentation split.

    Args:
        root: VOC2012 directory containing ImageSets/, JPEGImages/ and SegmentationClass/.
        train: read ``train.txt`` when True, otherwise ``val.txt``.

    Returns:
        (data, label): two parallel lists of paths -- ``JPEGImages/<id>.jpg`` and
        ``SegmentationClass/<id>.png`` for every whitespace-separated id in the split file.
    """
    split_name = 'train.txt' if train else 'val.txt'
    split_file = '/'.join((root, 'ImageSets', 'Segmentation', split_name))
    with open(split_file, 'r') as handle:
        image_ids = handle.read().split()

    image_paths = [os.path.join(root, 'JPEGImages', f'{image_id}.jpg') for image_id in image_ids]
    mask_paths = [os.path.join(root, 'SegmentationClass', f'{image_id}.png') for image_id in image_ids]
    return image_paths, mask_paths

def rand_crop(data, label, height, width):
    '''
    Crop the same random ``height`` x ``width`` window out of an image and its label.

    data is PIL.Image object
    label is PIL.Image object

    The window offset comes from ``data.size``. When the image already has exactly the
    crop size, the offset is (0, 0) and no random numbers are drawn. When the image is
    smaller than the crop, ``random.randint`` raises ValueError.
    '''
    img_w, img_h = data.size
    exact_fit = (img_w, img_h) == (width, height)
    # Draw the row offset before the column offset (same RNG consumption order).
    top = 0 if exact_fit else random.randint(0, img_h - height)
    left = 0 if exact_fit else random.randint(0, img_w - width)

    cropped_data = tfs.functional.crop(data, top, left, height, width)
    cropped_label = tfs.functional.crop(label, top, left, height, width)
    return cropped_data, cropped_label

In [3]:
# Pascal VOC class names; the position in this list is the integer label used for training.
classes = ['background','aeroplane','bicycle','bird','boat',
           'bottle','bus','car','cat','chair','cow','diningtable',
           'dog','horse','motorbike','person','potted plant',
           'sheep','sofa','train','tv/monitor']

# RGB color for each class: colormap[i] is the color of classes[i] in the label PNGs.
colormap = [[0,0,0],[128,0,0],[0,128,0], [128,128,0], [0,0,128],
            [128,0,128],[0,128,128],[128,128,128],[64,0,0],[192,0,0],
            [64,128,0],[192,128,0],[64,0,128],[192,0,128],
            [64,128,128],[192,128,128],[0,64,0],[128,64,0],
            [0,192,0],[128,192,0],[0,64,128]]

# The color -> label lookup table is built from these lists. It needs exactly one color
# per class, and the colors must be distinct, or two classes would collapse onto one label.
assert len(classes) == len(colormap), 'every class needs exactly one color'
assert len({tuple(color) for color in colormap}) == len(colormap), 'colormap colors must be unique'

len(classes), len(colormap)

(21, 21)

In [4]:
# Lookup table from a packed RGB value (R*256 + G)*256 + B to a class index.
# uint8 holds any index < 256 and keeps this 256**3-entry table at 16 MB.
# A default float64 table would take 128 MB.
# Any color that is not in `colormap` maps to 0 (background).
# NOTE(review): VOC label PNGs also contain a "void"/object-boundary color that is not in
# `colormap` (224, 224, 192 per the VOC devkit). Those pixels currently train as background.
# Consider mapping them to an ignore index for the loss -- confirm the intended behavior.
assert len(colormap) <= 256, 'uint8 lookup table supports at most 256 classes'
cm2lbl = np.zeros(256**3, dtype=np.uint8)
for i, cm in enumerate(colormap):
    cm2lbl[(cm[0] * 256 + cm[1]) * 256 + cm[2]] = i

def image2label(im):
    """Convert an RGB label image into a 2-D array of class indices.

    Args:
        im: RGB image (e.g. a PIL image after ``.convert('RGB')``). It is indexed as
            ``[:, :, 0..2]``, so it must have three channels.

    Returns:
        int64 numpy array of shape (H, W) holding the class index of each pixel.
    """
    # int32 avoids uint8 overflow when packing; the largest packed value 256**3 - 1 fits easily.
    data = np.array(im, dtype='int32')
    idx = (data[:, :, 0] * 256 + data[:, :, 1]) * 256 + data[:, :, 2]
    # int64 is the label dtype nn.CrossEntropyLoss expects for class targets.
    return cm2lbl[idx].astype('int64')

In [5]:
# Sanity check: decode one known VOC label image and inspect a patch near an object edge.
sample_label_path = '../input/pascal-voc-2012/VOC2012/SegmentationClass/2007_000033.png'
label_im = Image.open(sample_label_path).convert('RGB')
label = image2label(label_im)
label[150:160, 240:250]

array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
       [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
       [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
       [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
       [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])

In [22]:
def img_transforms(im, label, crop_size):
    """Randomly crop an image/label pair and convert both to training tensors.

    Args:
        im: PIL input image.
        label: PIL RGB label image (decoded through ``image2label``).
        crop_size: (height, width) of the random crop.

    Returns:
        (image_tensor, label_tensor). The image is converted by ToTensor and normalized
        with the listed per-channel mean/std. The label is the ``torch.from_numpy`` view
        of the class-index array from ``image2label``.
    """
    crop_h, crop_w = crop_size[0], crop_size[1]
    im, label = rand_crop(im, label, crop_h, crop_w)

    to_normalized_tensor = tfs.Compose([
        tfs.ToTensor(),
        tfs.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    image_tensor = to_normalized_tensor(im)
    label_tensor = torch.from_numpy(image2label(label))
    return image_tensor, label_tensor

In [7]:
class VOCSegDataset(torch.utils.data.Dataset):
    '''
    voc dataset

    Pascal VOC segmentation split wrapped as a torch Dataset.

    Args:
        train: True for the training split, False for the validation split
            (passed to ``read_images``).
        crop_size: (height, width). Pairs whose image or label is smaller than this
            in either dimension are dropped, because ``rand_crop`` cannot crop them.
        transforms: callable ``(img, label, crop_size) -> (img, label)`` applied per item.
    '''
    def __init__(self, train, crop_size, transforms):
        self.crop_size = crop_size
        self.transforms = transforms
        data_list, label_list = read_images(train=train)
        # Filter image/label *pairs* together so the two lists can never drift out of
        # alignment, which independent filtering of each list would allow.
        pairs = [(data, label) for data, label in zip(data_list, label_list)
                 if self._is_large_enough(data) and self._is_large_enough(label)]
        self.data_list = [data for data, _ in pairs]
        self.label_list = [label for _, label in pairs]
        print('Read ' + str(len(self.data_list)) + ' images')

    def _is_large_enough(self, path):
        '''Return True if the image at ``path`` is at least ``crop_size`` in both dimensions.'''
        # PIL's Image.open is lazy, so reading .size does not decode pixels.
        # The with-block closes the file handle instead of leaking it until GC.
        with Image.open(path) as im:
            width, height = im.size
        return height >= self.crop_size[0] and width >= self.crop_size[1]

    def _filter(self, images):
        '''Keep only the paths in ``images`` whose image is at least ``crop_size`` large.'''
        return [im for im in images if self._is_large_enough(im)]

    def __getitem__(self, idx):
        # convert() forces a full load, after which the underlying file can be closed safely.
        with Image.open(self.data_list[idx]) as img_file:
            # Guarantees 3 channels for ToTensor/Normalize; for RGB input this is a plain copy.
            img = img_file.convert('RGB')
        with Image.open(self.label_list[idx]) as label_file:
            # image2label indexes [:, :, 0..2], so the label must be RGB as well.
            label = label_file.convert('RGB')
        return self.transforms(img, label, self.crop_size)

    def __len__(self):
        return len(self.data_list)

In [11]:
class CityscapesDataset(torch.utils.data.Dataset):
    """Cityscapes image pairs stored side by side in single ``.jpg`` files.

    For each file, columns 0-255 are returned as the (normalized) input and columns
    256-511 as the ground truth, both as float CHW tensors.

    Args:
        root_dir: directory holding the ``.jpg`` files. A trailing path separator is optional.
    """

    def __init__(self, root_dir):
        # os.path.join works with or without a trailing separator on root_dir.
        # Sorting makes the index -> file mapping deterministic; glob order is arbitrary.
        self.data_dir = sorted(glob.glob(os.path.join(root_dir, '*.jpg')))

    def __len__(self):
        return len(self.data_dir)

    def __getitem__(self, idx):
        '''
        Args: scalar index
        Returns: input_img, ground_truth
        '''
        path = self.data_dir[idx]
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f'could not read image: {path}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img, gt = torch.tensor(img[:, 0:256, :], dtype=torch.float), torch.tensor(img[:, 256:512, :], dtype=torch.float)
        # HWC -> CHW
        img, gt = img.permute(2, 0, 1), gt.permute(2, 0, 1)
        # Only the input is normalized (in place); gt keeps raw 0-255 RGB values.
        torchvision.transforms.Normalize([73.27711, 82.75635, 72.56262], [44.13835, 45.155933, 44.69525], inplace=True)(img)
        return img, gt

In [31]:
# Alternative (unused) dataset:
# train = CityscapesDataset(r'../input/cityscapes-image-pairs/cityscapes_data/train/')
# val = CityscapesDataset(r'../input/cityscapes-image-pairs/cityscapes_data/val/')

# Crop size as (height, width); VOCSegDataset drops images smaller than this.
input_shape = (320, 480)
trainSet, testSet = (VOCSegDataset(is_train, input_shape, img_transforms)
                     for is_train in (True, False))

Read 1114 images
Read 1078 images


In [32]:
# Shuffle only the training split; validation order does not matter.
train_data = torch.utils.data.DataLoader(trainSet, batch_size=64, shuffle=True, num_workers=4)
valid_data = torch.utils.data.DataLoader(testSet, batch_size=128, num_workers=4)

In [14]:
# calculate mean and std
# mean, std = torch.zeros(3), torch.zeros(3)
# for img, gt in train:
#     mean += torch.mean(img, dim=(1, 2))
#     std += torch.std(img, dim=(1, 2))
# mean = (mean / len(train)).numpy()
# std = (std / len(train)).numpy()
# print(mean, std)

In [15]:
# List the entry points published by the torchvision hub repo (re-downloads the repo each run).
torch.hub.list(github='pytorch/vision', force_reload=True)

Downloading: "https://github.com/pytorch/vision/archive/master.zip" to /root/.cache/torch/hub/master.zip


['alexnet',
 'deeplabv3_resnet101',
 'densenet121',
 'densenet161',
 'densenet169',
 'densenet201',
 'fcn_resnet101',
 'googlenet',
 'inception_v3',
 'mobilenet_v2',
 'resnet101',
 'resnet152',
 'resnet18',
 'resnet34',
 'resnet50',
 'resnext101_32x8d',
 'resnext50_32x4d',
 'shufflenet_v2_x0_5',
 'shufflenet_v2_x1_0',
 'squeezenet1_0',
 'squeezenet1_1',
 'vgg11',
 'vgg11_bn',
 'vgg13',
 'vgg13_bn',
 'vgg16',
 'vgg16_bn',
 'vgg19',
 'vgg19_bn',
 'wide_resnet101_2',
 'wide_resnet50_2']

In [16]:
num_classes = len(classes)
# ResNet-34 with pretrained weights from the torchvision hub; FCN below slices its children.
pretrained_net = torch.hub.load('pytorch/vision', 'resnet34', pretrained=True)

class FCN(nn.Module):
    """FCN-8s-style segmentation network on top of a ResNet backbone.

    The backbone's children are split into three stages (1/8, 1/16 and 1/32 of the input
    resolution). Each stage gets a 1x1 score conv. The coarser scores are upsampled 2x
    and added to the finer ones, and the fused 1/8-resolution map is upsampled 8x back
    to input resolution.

    Args:
        num_classes: number of output channels (one logit map per class).
        backbone: module whose ``children()`` end with
            (..., stage-1/8, stage-1/16, stage-1/32, pool, fc), like torchvision's ResNet.
            Defaults to the notebook-level ``pretrained_net``, so ``FCN(num_classes)``
            behaves as before.
    """
    def __init__(self, num_classes, backbone=None):
        super(FCN, self).__init__()
        if backbone is None:
            backbone = pretrained_net
        children = list(backbone.children())

        # Everything except the last four children (layer3, layer4, avgpool, fc for ResNet).
        self.stage1 = nn.Sequential(*children[:-4])  # -> 1/8 resolution
        self.stage2 = children[-4]  # -> 1/16
        self.stage3 = children[-3]  # -> 1/32

        # 1x1 convs turning backbone features into per-class scores.
        # NOTE(review): 512/256/128 input channels presumably match ResNet-18/34 only;
        # confirm before swapping in a different backbone.
        self.scores1 = nn.Conv2d(512, num_classes, 1)
        self.scores2 = nn.Conv2d(256, num_classes, 1)
        self.scores3 = nn.Conv2d(128, num_classes, 1)

        # kernel 16, stride 8, padding 4: out = (in - 1) * 8 - 8 + 16 = 8 * in.
        self.upsample_8x = nn.ConvTranspose2d(num_classes, num_classes, 16, 8, 4, bias=False)

        # kernel 4, stride 2, padding 1: out = 2 * in. Despite its name this layer
        # upsamples 2x; the attribute name is kept so state_dict keys stay compatible.
        self.upsample_4x = nn.ConvTranspose2d(num_classes, num_classes, 4, 2, 1, bias=False)

        self.upsample_2x = nn.ConvTranspose2d(num_classes, num_classes, 4, 2, 1, bias=False)

    def forward(self, x):
        x = self.stage1(x)
        s1 = x  # 1/8

        x = self.stage2(x)
        s2 = x  # 1/16

        x = self.stage3(x)
        s3 = x  # 1/32

        # Fuse 1/32 scores into the 1/16 scores.
        s3 = self.scores1(s3)
        s3 = self.upsample_2x(s3)
        s2 = self.scores2(s2)
        s2 = s2 + s3

        # Fuse the result into the 1/8 scores.
        s1 = self.scores3(s1)
        s2 = self.upsample_4x(s2)
        s = s1 + s2

        # Upsample the fully fused map `s`, which includes the 1/8 skip connection.
        return self.upsample_8x(s)

Using cache found in /root/.cache/torch/hub/pytorch_vision_master
Downloading: "https://download.pytorch.org/models/resnet34-333f7ec4.pth" to /root/.cache/torch/checkpoints/resnet34-333f7ec4.pth
100%|██████████| 83.3M/83.3M [00:01<00:00, 57.1MB/s]


In [20]:
# Module.cuda() moves parameters in place and returns the module itself.
model = FCN(num_classes).cuda()
optimizer = torch.optim.Adam(params=model.parameters())
loss_fn = nn.CrossEntropyLoss()

In [18]:
# nn.Module's string form lists every submodule with its hyperparameters.
print(str(model))

FCN(
  (stage1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    

In [None]:
NUM_EPOCHS = 100

for epoch in range(NUM_EPOCHS):
    model.train()

    running_loss = 0.0
    n_batches = 0

    for img, gt in train_data:
        img = img.cuda()
        gt = gt.cuda()
        out = model(img)  # per-class logits, one channel per class
        loss = loss_fn(out, gt)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        n_batches += 1

    # Report the mean batch loss so the number is comparable across dataset/batch sizes.
    # max(..., 1) avoids division by zero for an empty loader.
    mean_loss = running_loss / max(n_batches, 1)
    print(f'epoch: {epoch + 1}/{NUM_EPOCHS}, loss: {mean_loss:.4f}')
        