In [1]:
%matplotlib inline

# Fully Convolutional Networks for Semantic Segmentation

## Variables

In [2]:
import torch

BATCH_SIZE = 1
NUM_EPOCHS = 50
NUM_CLASSES = 3173
USE_CUDA = torch.cuda.is_available()

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

plt.ion()   # interactive mode

## Utils

In [4]:
def accuracy(batch_data, pred):
    (imgs, segs, infos) = batch_data
    _, preds = torch.max(pred.data.cpu(), dim=1)
    valid = (segs >= 0)
    acc = 1.0 * torch.sum(valid * (preds == segs)) / (torch.sum(valid) + 1e-10)
    return acc, torch.sum(valid)

## Dataset

In [5]:
import os
import random
import numpy as np
import torch
import torch.utils.data as torchdata
from torchvision import transforms
from scipy.misc import imread, imresize

class Dataset(torchdata.Dataset):
    def __init__(self, root_folder, image_list, max_sample=-1, is_train=1):
        self.root_img = './{}/dataset/images'.format(root_folder)
        self.root_seg = './{}/dataset/annotations'.format(root_folder)
        self.imgSize = 100
        self.segSize = 100
        self.is_train = is_train

        # mean and std
        self.img_transform = transforms.Compose([
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])])

        self.list_sample = [x.rstrip() for x in open(image_list, 'r')]

        if self.is_train:
            random.shuffle(self.list_sample)
        if max_sample > 0:
            self.list_sample = self.list_sample[0:max_sample]
        num_sample = len(self.list_sample)
        assert num_sample > 0
        print('# samples: {}'.format(num_sample))

    def _scale_and_crop(self, img, seg, cropSize, is_train):
        h, w = img.shape[0], img.shape[1]

        if is_train:
            # random scale
            scale = random.random() + 0.5     # 0.5-1.5
            scale = max(scale, 1. * cropSize / (min(h, w) - 1))
        else:
            # scale to crop size
            scale = 1. * cropSize / (min(h, w) - 1)

        img_scale = imresize(img, scale, interp='bilinear')
        seg_scale = imresize(seg, scale, interp='nearest')

        h_s, w_s = img_scale.shape[0], img_scale.shape[1]
        if is_train:
            # random crop
            x1 = random.randint(0, w_s - cropSize)
            y1 = random.randint(0, h_s - cropSize)
        else:
            # center crop
            x1 = (w_s - cropSize) // 2
            y1 = (h_s - cropSize) // 2

        img_crop = img_scale[y1: y1 + cropSize, x1: x1 + cropSize, :]
        seg_crop = seg_scale[y1: y1 + cropSize, x1: x1 + cropSize]
        return img_crop, seg_crop

    def _flip(self, img, seg):
        img_flip = img[:, ::-1, :]
        seg_flip = seg[:, ::-1]
        return img_flip, seg_flip

    def __getitem__(self, index):
        img_basename = self.list_sample[index]
        path_img = os.path.join(self.root_img, img_basename)
        path_seg = os.path.join(self.root_seg,
                                img_basename.replace('.jpg', '.png'))

        assert os.path.exists(path_img), '[{}] does not exist'.format(path_img)
        assert os.path.exists(path_seg), '[{}] does not exist'.format(path_seg)

        # load image and label
        try:
            img = imread(path_img, mode='RGB')
            seg = imread(path_seg)
            assert(img.ndim == 3)
            assert(seg.ndim == 2)
            assert(img.shape[0] == seg.shape[0])
            assert(img.shape[1] == seg.shape[1])

            # random scale, crop, flip
            if self.imgSize > 0:
                img, seg = self._scale_and_crop(img, seg,
                                                self.imgSize, self.is_train)
                if random.choice([-1, 1]) > 0:
                    img, seg = self._flip(img, seg)

            # image to float
            img = img.astype(np.float32) / 255.
            img = img.transpose((2, 0, 1))

            if self.segSize > 0:
                seg = imresize(seg, (self.segSize, self.segSize),
                               interp='nearest')

            # label to int from -1 to 149
            seg = seg.astype(np.int) - 1

            # to torch tensor
            image = torch.from_numpy(img)
            segmentation = torch.from_numpy(seg)
        except Exception as e:
            print('Failed loading image/segmentation [{}]: {}'
                  .format(path_img, e))
            # dummy data
            image = torch.zeros(3, self.imgSize, self.imgSize)
            segmentation = -1 * torch.ones(self.segSize, self.segSize).long()
            return image, segmentation, img_basename

        # substracted by mean and divided by std
        image = self.img_transform(image)

        return image, segmentation, img_basename

    def __len__(self):
        return len(self.list_sample)



In [6]:
# Dataset and Loader
def load_dataset(train_list, val_list, root_folder):
    dataset_train = Dataset(root_folder, '{}/{}'.format(root_folder, train_list), is_train=1)
    dataset_val = Dataset(root_folder, '{}/{}'.format(root_folder, val_list), is_train=0)
    loader_train = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=1,
        drop_last=True)
    loader_val = torch.utils.data.DataLoader(
        dataset_val,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=1,
        drop_last=True) 
    return loader_train, loader_val

## Train

In [7]:
def train_step(batch_data, model, optimizer, criterion, train=True, print_accuracy=False):
    (imgs, segs, infos) = batch_data

    # feed input data
    input_img = Variable(imgs, volatile=not is_train).cuda() if USE_CUDA else Variable(imgs, volatile=not train)
    label_seg = Variable(segs, volatile=not is_train).cuda() if USE_CUDA else Variable(segs, volatile=not train)

    # forward pass
    pred = model(input_img)
    err = criterion(pred, label_seg)

    # Backward
    if train:
        err.backward()
        optimizer.step()
    
    if print_accuracy:
        print('Accuracy: {}'.format(accuracy(batch_data, pred)))
    
    return err.data[0]
   
def train(train_loader, val_loader, model, optimizer, criterion):
    print('start training')
    train_losses = []
    val_losses = []
    for epoch in range(NUM_EPOCHS):
        print('epoch')
        timestamp1 = time.time()
        training_loss = 0
        val_loss = 0
        for index, batch_data in enumerate(train_loader):
            print('step')
            print_accuracy = False
            if index % 100:
                print_accuracy = True
            training_loss += train_step(batch_data, model, optimizer, criterion, train=True, print_accuracy = True)
        for index, batch_data in enumerate(val_loader):
            val_loss += train_step(batch_data, model, optimizer, criterion, train=False)
        timestamp2 = time.time()
        train_losses.append(training_loss)
        val_losses.append(val_loss)
        print('\nEpoch: {} | TRAINING Loss: {} | TESTING Loss: {} | Time: {}\n'.format(
            epoch_num + 1, training_loss, val_loss, timestamp2 - timestamp1))
        return train_losses, val_losses

## Model

In [8]:
class VGG_FCN_8(nn.Module):

    def __init__(self):
        super(VGG_FCN_8, self).__init__()
        self.features = nn.Sequential(*list(torchvision.models.vgg16(pretrained=True).features))
        self.upsampler_x32_sequence = nn.Sequential(nn.Conv2d(512, 4096, kernel_size=7),
                                                    nn.ReLU(True),
                                                    nn.Dropout(),
                                                    nn.Conv2d(4096, 4096, kernel_size=1),
                                                    nn.ReLU(True),
                                                    nn.Dropout(),
                                                    nn.Conv2d(4096, NUM_CLASSES, kernel_size=1)
                                                   )
        self.upsampler_x32 = nn.ConvTranspose2d(NUM_CLASSES, NUM_CLASSES, kernel_size=4, stride=2, bias=False)
        self.upsampler_x16_sequence = nn.Sequential(nn.Conv2d(512, NUM_CLASSES, kernel_size=1))
        self.upsampler_x16 = nn.ConvTranspose2d(NUM_CLASSES, NUM_CLASSES,  kernel_size=4, stride=2, bias=False)                                    
        self.upsampler_x8_sequence = nn.Sequential(nn.Conv2d(256, NUM_CLASSES, kernel_size=1))
        self.upsampler_x8 = nn.ConvTranspose2d(NUM_CLASSES, NUM_CLASSES, kernel_size=16, stride=8, bias=False)
        
        self.features[0].padding = (100, 100)
        
    def forward(self, x):
        x_size = x.size()
        output = x
        for i in range(17):
            output = self.features[i](output)
        upsample_x8 = self.upsampler_x8_sequence(0.0001 * output)
        for i in range(17, 24):
            output = self.features[i](output)
        upsample_x16 = self.upsampler_x16_sequence(0.01 * output)
        for i in range(24, 31):
            output = self.features[i](output)
        upsample_x32 = self.upsampler_x32_sequence(output)
        upscore_x32 = self.upsampler_x32(upsample_x32)
        upscore_x16 = self.upsampler_x16(upsample_x16[:, :, 5: (5 + upscore_x32.size()[2]), 5: (5 + upscore_x32.size()[3])] + upscore_x32) 
        upscore_x8 = self.upsampler_x8(upsample_x8[:, :, 9: (9 + upscore_x16.size()[2]), 9: (9 + upscore_x16.size()[3])] + upscore_x16)
        upscore_x8 = upscore_x8[:, :, 31: (31 + x_size[2]), 31: (31 + x_size[3])].contiguous()
        return upscore_x8

In [None]:
def initialize_model():
    model_conv = VGG_FCN_8()
    # freeze vgg params
    for param in model_conv.features.parameters():
        param.requires_grad = False
    # Parameters of newly constructed modules have requires_grad=True by default
    if USE_CUDA:
        model_conv = model_conv.cuda()

    criterion = nn.CrossEntropyLoss()
    optimizer_conv = torch.optim.SGD(filter(lambda p: p.requires_grad,  model_conv.parameters()), lr=1e-3, momentum=0.9)

    # Decay LR by a factor of 0.1 every 7 epochs
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)
    return model_conv, optimizer_conv, criterion

## Main

In [None]:
train_list = 'train.txt'
val_list = 'val.txt'
root_folder = 'data'
loader_train, loader_val = load_dataset(train_list, val_list, root_folder)
print('Dataset loaded.')
model, optimizer, criterion = initialize_model()
print('Model initialized, starting training...\n')
train(loader_train, loader_val, model, optimizer, criterion)
print('Training has ended.')

# samples: 20210
# samples: 2000
Dataset loaded.
Model initialized, starting training...

start training
epoch


`imread` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imread`` instead.
`imread` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imread`` instead.
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``skimage.transform.resize`` instead.
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``skimage.transform.resize`` instead.
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``skimage.transform.resize`` instead.


step
