Modularise code
Kaixhin committed Jun 27, 2017
1 parent 4194e99 commit 93341fb
Showing 4 changed files with 148 additions and 141 deletions.
README.md (2 changes: 1 addition & 1 deletion)
@@ -1,7 +1,7 @@
FCN-semantic-segmentation
=========================

Simple end-to-end semantic segmentation using fully convolutional networks [[1]](#references). Takes a pretrained 34-layer ResNet [[2]](#references), removes the fully connected layers, and adds transposed convolution layers with skip residual connections from lower layers. Initialises upsampling convolutions with bilinear interpolation filters and zeros the final (classification) layer. Uses an independent cross-entropy loss per class.
Simple end-to-end semantic segmentation using fully convolutional networks [[1]](#references). Takes a pretrained 34-layer ResNet [[2]](#references), removes the fully connected layers, and adds transposed convolution layers with skip connections from lower layers. Initialises upsampling convolutions with bilinear interpolation filters and zeros the final (classification) layer. Uses an independent cross-entropy loss per class.

Calculates and plots class-wise and mean intersection-over-union. Checkpoints the network every epoch.

data.py (63 changes: 63 additions & 0 deletions)
@@ -0,0 +1,63 @@
import os
import random
from PIL import Image
import torch
from torch.utils.data import Dataset


# Labels: -1 license plate, 0 unlabeled, 1 ego vehicle, 2 rectification border, 3 out of roi, 4 static, 5 dynamic, 6 ground, 7 road, 8 sidewalk, 9 parking, 10 rail track, 11 building, 12 wall, 13 fence, 14 guard rail, 15 bridge, 16 tunnel, 17 pole, 18 polegroup, 19 traffic light, 20 traffic sign, 21 vegetation, 22 terrain, 23 sky, 24 person, 25 rider, 26 car, 27 truck, 28 bus, 29 caravan, 30 trailer, 31 train, 32 motorcycle, 33 bicycle
num_classes = 20
full_to_train = {-1: 19, 0: 19, 1: 19, 2: 19, 3: 19, 4: 19, 5: 19, 6: 19, 7: 0, 8: 1, 9: 19, 10: 19, 11: 2, 12: 3, 13: 4, 14: 19, 15: 19, 16: 19, 17: 5, 18: 19, 19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 25: 12, 26: 13, 27: 14, 28: 15, 29: 19, 30: 19, 31: 16, 32: 17, 33: 18}
train_to_full = {0: 7, 1: 8, 2: 11, 3: 12, 4: 13, 5: 17, 6: 19, 7: 20, 8: 21, 9: 22, 10: 23, 11: 24, 12: 25, 13: 26, 14: 27, 15: 28, 16: 31, 17: 32, 18: 33, 19: 0}
full_to_colour = {0: (0, 0, 0), 7: (128, 64, 128), 8: (244, 35, 232), 11: (70, 70, 70), 12: (102, 102, 156), 13: (190, 153, 153), 17: (153, 153, 153), 19: (250, 170, 30), 20: (220, 220, 0), 21: (107, 142, 35), 22: (152, 251, 152), 23: (70, 130, 180), 24: (220, 20, 60), 25: (255, 0, 0), 26: (0, 0, 142), 27: (0, 0, 70), 28: (0, 60, 100), 31: (0, 80, 100), 32: (0, 0, 230), 33: (119, 11, 32)}


class CityscapesDataset(Dataset):
  def __init__(self, split='train', crop=None, flip=False):
    super().__init__()
    self.crop = crop
    self.flip = flip
    self.inputs = []
    self.targets = []

    for root, _, filenames in os.walk(os.path.join('leftImg8bit_trainvaltest', 'leftImg8bit', split)):
      for filename in filenames:
        if os.path.splitext(filename)[1] == '.png':
          filename_base = '_'.join(filename.split('_')[:-1])
          target_root = os.path.join('gtFine_trainvaltest', 'gtFine', split, os.path.basename(root))
          self.inputs.append(os.path.join(root, filename_base + '_leftImg8bit.png'))
          self.targets.append(os.path.join(target_root, filename_base + '_gtFine_labelIds.png'))

  def __len__(self):
    return len(self.inputs)

  def __getitem__(self, i):
    # Load images and perform augmentations with PIL
    input, target = Image.open(self.inputs[i]), Image.open(self.targets[i])
    # Random uniform crop
    if self.crop is not None:
      w, h = input.size
      x1, y1 = random.randint(0, w - self.crop), random.randint(0, h - self.crop)
      input, target = input.crop((x1, y1, x1 + self.crop, y1 + self.crop)), target.crop((x1, y1, x1 + self.crop, y1 + self.crop))
    # Random horizontal flip
    if self.flip:
      if random.random() < 0.5:
        input, target = input.transpose(Image.FLIP_LEFT_RIGHT), target.transpose(Image.FLIP_LEFT_RIGHT)

    # Convert to tensors
    w, h = input.size
    input = torch.ByteTensor(torch.ByteStorage.from_buffer(input.tobytes())).view(h, w, 3).permute(2, 0, 1).float().div(255)
    target = torch.ByteTensor(torch.ByteStorage.from_buffer(target.tobytes())).view(h, w).long()
    # Normalise input
    input[0].add_(-0.485).div_(0.229)
    input[1].add_(-0.456).div_(0.224)
    input[2].add_(-0.406).div_(0.225)
    # Convert to training labels
    remapped_target = target.clone()
    for k, v in full_to_train.items():
      remapped_target[target == k] = v
    # Create one-hot encoding
    target = torch.zeros(num_classes, h, w)
    for c in range(num_classes):
      target[c][remapped_target == c] = 1
    return input, target, remapped_target  # Return x, y (one-hot), y (index)
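
As a usage sketch (not part of the commit): the two mapping dicts invert cleanly, so a prediction made in training-id space can be pushed back through train_to_full and painted with full_to_colour for visualisation. The tensor shapes here are assumptions:

import torch
from data import full_to_colour, train_to_full

pred = torch.zeros(256, 256).long()       # assumed H x W map of training ids, e.g. an argmax over class scores
colour = torch.zeros(3, 256, 256).byte()  # RGB canvas
for train_id, full_id in train_to_full.items():
  mask = pred == train_id
  for c, value in enumerate(full_to_colour[full_id]):
    colour[c][mask] = value
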
main.py (147 changes: 7 additions & 140 deletions)
@@ -1,18 +1,18 @@
from argparse import ArgumentParser
import os
import random
from matplotlib import pyplot as plt
import torch
from PIL import Image
from torch import optim
from torch import nn
from torch.nn import init
from torch.nn import functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from torchvision.models.resnet import BasicBlock, ResNet
from torch.utils.data import DataLoader
from torchvision import models
from torchvision.utils import save_image
from matplotlib import pyplot as plt

from data import CityscapesDataset, num_classes, full_to_colour, train_to_full
from model import FeatureResNet, SegResNet


# Setup
@@ -31,146 +31,14 @@
os.makedirs('results')
plt.switch_backend('agg') # Allow plotting when running remotely

# Labels: -1 license plate, 0 unlabeled, 1 ego vehicle, 2 rectification border, 3 out of roi, 4 static, 5 dynamic, 6 ground, 7 road, 8 sidewalk, 9 parking, 10 rail track, 11 building, 12 wall, 13 fence, 14 guard rail, 15 bridge, 16 tunnel, 17 pole, 18 polegroup, 19 traffic light, 20 traffic sign, 21 vegetation, 22 terrain, 23 sky, 24 person, 25 rider, 26 car, 27 truck, 28 bus, 29 caravan, 30 trailer, 31 train, 32 motorcycle, 33 bicycle
num_classes = 20
full_to_train = {-1: 19, 0: 19, 1: 19, 2: 19, 3: 19, 4: 19, 5: 19, 6: 19, 7: 0, 8: 1, 9: 19, 10: 19, 11: 2, 12: 3, 13: 4, 14: 19, 15: 19, 16: 19, 17: 5, 18: 19, 19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 25: 12, 26: 13, 27: 14, 28: 15, 29: 19, 30: 19, 31: 16, 32: 17, 33: 18}
train_to_full = {0: 7, 1: 8, 2: 11, 3: 12, 4: 13, 5: 17, 6: 19, 7: 20, 8: 21, 9: 22, 10: 23, 11: 24, 12: 25, 13: 26, 14: 27, 15: 28, 16: 31, 17: 32, 18: 33, 19: 0}
full_to_colour = {0: (0, 0, 0), 7: (128, 64, 128), 8: (244, 35, 232), 11: (70, 70, 70), 12: (102, 102, 156), 13: (190, 153, 153), 17: (153, 153, 153), 19: (250, 170, 30), 20: (220, 220, 0), 21: (107, 142, 35), 22: (152, 251, 152), 23: (70, 130, 180), 24: (220, 20, 60), 25: (255, 0, 0), 26: (0, 0, 142), 27: (0, 0, 70), 28: (0, 60, 100), 31: (0, 80, 100), 32: (0, 0, 230), 33: (119, 11, 32)}


# Data
class CityscapesDataset(Dataset):
  def __init__(self, split='train', crop=None, flip=False):
    super().__init__()
    self.crop = crop
    self.flip = flip
    self.inputs = []
    self.targets = []

    for root, _, filenames in os.walk(os.path.join('leftImg8bit_trainvaltest', 'leftImg8bit', split)):
      for filename in filenames:
        if os.path.splitext(filename)[1] == '.png':
          filename_base = '_'.join(filename.split('_')[:-1])
          target_root = os.path.join('gtFine_trainvaltest', 'gtFine', split, os.path.basename(root))
          self.inputs.append(os.path.join(root, filename_base + '_leftImg8bit.png'))
          self.targets.append(os.path.join(target_root, filename_base + '_gtFine_labelIds.png'))

  def __len__(self):
    return len(self.inputs)

  def __getitem__(self, i):
    # Load images and perform augmentations with PIL
    input, target = Image.open(self.inputs[i]), Image.open(self.targets[i])
    # Random uniform crop
    if self.crop is not None:
      w, h = input.size
      x1, y1 = random.randint(0, w - self.crop), random.randint(0, h - self.crop)
      input, target = input.crop((x1, y1, x1 + self.crop, y1 + self.crop)), target.crop((x1, y1, x1 + self.crop, y1 + self.crop))
    # Random horizontal flip
    if self.flip:
      if random.random() < 0.5:
        input, target = input.transpose(Image.FLIP_LEFT_RIGHT), target.transpose(Image.FLIP_LEFT_RIGHT)

    # Convert to tensors
    w, h = input.size
    input = torch.ByteTensor(torch.ByteStorage.from_buffer(input.tobytes())).view(h, w, 3).permute(2, 0, 1).float().div(255)
    target = torch.ByteTensor(torch.ByteStorage.from_buffer(target.tobytes())).view(h, w).long()
    # Normalise input
    input[0].add_(-0.485).div_(0.229)
    input[1].add_(-0.456).div_(0.224)
    input[2].add_(-0.406).div_(0.225)
    # Convert to training labels
    remapped_target = target.clone()
    for k, v in full_to_train.items():
      remapped_target[target == k] = v
    # Create one-hot encoding
    target = torch.zeros(num_classes, h, w)
    for c in range(num_classes):
      target[c][remapped_target == c] = 1
    return input, target, remapped_target  # Return x, y (one-hot), y (index)


train_dataset = CityscapesDataset(split='train', crop=args.crop_size, flip=True)
val_dataset = CityscapesDataset(split='val')
weight_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.workers, pin_memory=True)
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=1, num_workers=args.workers, pin_memory=True)


# Models
# Returns 2D convolutional layer with space-preserving padding
def conv(in_planes, out_planes, kernel_size=3, stride=1, dilation=1, bias=False, transposed=False):
  if transposed:
    layer = nn.ConvTranspose2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=1, output_padding=1, dilation=dilation, bias=bias)
    # Bilinear interpolation init
    w = torch.Tensor(kernel_size, kernel_size)
    centre = kernel_size % 2 == 1 and stride - 1 or stride - 0.5
    for y in range(kernel_size):
      for x in range(kernel_size):
        w[y, x] = (1 - abs((x - centre) / stride)) * (1 - abs((y - centre) / stride))
    layer.weight.data.copy_(w.div(in_planes).repeat(out_planes, in_planes, 1, 1))
  else:
    padding = (kernel_size + 2 * (dilation - 1)) // 2
    layer = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias)
  if bias:
    init.constant(layer.bias, 0)
  return layer


# Returns 2D batch normalisation layer
def bn(planes):
  layer = nn.BatchNorm2d(planes)
  # Use mean 0, standard deviation 1 init
  init.constant(layer.weight, 1)
  init.constant(layer.bias, 0)
  return layer


class FeatureResNet(ResNet):
  def __init__(self):
    super().__init__(BasicBlock, [3, 4, 6, 3], 1000)

  def forward(self, x):
    x1 = self.conv1(x)
    x = self.bn1(x1)
    x = self.relu(x)
    x2 = self.maxpool(x)
    x = self.layer1(x2)
    x3 = self.layer2(x)
    x4 = self.layer3(x3)
    x5 = self.layer4(x4)
    return x1, x2, x3, x4, x5


class SegResNet(nn.Module):
  def __init__(self, num_classes, pretrained_net):
    super().__init__()
    self.pretrained_net = pretrained_net
    self.relu = nn.ReLU(inplace=True)
    self.conv5 = conv(512 + num_classes, 256, stride=2, transposed=True)
    self.bn5 = bn(256)
    self.conv6 = conv(256 + num_classes, 128, stride=2, transposed=True)
    self.bn6 = bn(128)
    self.conv7 = conv(128 + num_classes, 64, stride=2, transposed=True)
    self.bn7 = bn(64)
    self.conv8 = conv(64 + num_classes, 64, stride=2, transposed=True)
    self.bn8 = bn(64)
    self.conv9 = conv(64 + num_classes, 32, stride=2, transposed=True)
    self.bn9 = bn(32)
    self.conv10 = conv(32, num_classes, kernel_size=7, gain='linear')
    init.constant(self.conv10.weight, 0)  # Zero init

  def forward(self, x):
    x1, x2, x3, x4, x5 = self.pretrained_net(x)
    x = self.relu(self.bn5(self.conv5(x5)))
    x = self.relu(self.bn6(self.conv6(x + x4)))
    x = self.relu(self.bn7(self.conv7(x + x3)))
    x = self.relu(self.bn8(self.conv8(x + x2)))
    x = self.relu(self.bn9(self.conv9(x + x1)))
    x = self.conv10(x)
    return x


# Training/Testing
pretrained_net = FeatureResNet()
pretrained_net.load_state_dict(models.resnet34(pretrained=True).state_dict())
@@ -256,8 +124,7 @@ def test(e):
  plt.plot(es, mean_scores, 'b-')
  plt.xlabel('Epoch')
  plt.ylabel('Mean IoU')
  plt.xlim(xmin=0, xmax=args.epochs)
  plt.savefig(os.path.join('results', 'ious.eps'))
  plt.savefig(os.path.join('results', 'ious.png'))
  plt.close()
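
Because FeatureResNet keeps torchvision's ResNet-34 layout, the pretrained state dict loads directly, as the pretrained_net lines above show. A quick shape check of the assembled network (a sketch using this PyTorch era's Variable wrapper; the input sides are multiples of 32 so the five stride-2 upsamplings in model.py mirror the encoder's five downsamplings exactly):

import torch
from torch.autograd import Variable
from torchvision import models
from data import num_classes
from model import FeatureResNet, SegResNet

pretrained_net = FeatureResNet()
pretrained_net.load_state_dict(models.resnet34(pretrained=True).state_dict())
net = SegResNet(num_classes, pretrained_net)
scores = net(Variable(torch.randn(1, 3, 256, 256)))
print(scores.size())  # torch.Size([1, 20, 256, 256])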


model.py (77 changes: 77 additions & 0 deletions)
@@ -0,0 +1,77 @@
import torch
from torch import nn
from torch.nn import init
from torchvision.models.resnet import BasicBlock, ResNet


# Returns 2D convolutional layer with space-preserving padding
def conv(in_planes, out_planes, kernel_size=3, stride=1, dilation=1, bias=False, transposed=False):
  if transposed:
    layer = nn.ConvTranspose2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=1, output_padding=1, dilation=dilation, bias=bias)
    # Bilinear interpolation init
    w = torch.Tensor(kernel_size, kernel_size)
    centre = kernel_size % 2 == 1 and stride - 1 or stride - 0.5
    for y in range(kernel_size):
      for x in range(kernel_size):
        w[y, x] = (1 - abs((x - centre) / stride)) * (1 - abs((y - centre) / stride))
    layer.weight.data.copy_(w.div(in_planes).repeat(out_planes, in_planes, 1, 1))
  else:
    padding = (kernel_size + 2 * (dilation - 1)) // 2
    layer = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias)
  if bias:
    init.constant(layer.bias, 0)
  return layer


# Returns 2D batch normalisation layer
def bn(planes):
  layer = nn.BatchNorm2d(planes)
  # Use mean 0, standard deviation 1 init
  init.constant(layer.weight, 1)
  init.constant(layer.bias, 0)
  return layer


class FeatureResNet(ResNet):
  def __init__(self):
    super().__init__(BasicBlock, [3, 4, 6, 3], 1000)

  def forward(self, x):
    x1 = self.conv1(x)
    x = self.bn1(x1)
    x = self.relu(x)
    x2 = self.maxpool(x)
    x = self.layer1(x2)
    x3 = self.layer2(x)
    x4 = self.layer3(x3)
    x5 = self.layer4(x4)
    return x1, x2, x3, x4, x5


class SegResNet(nn.Module):
  def __init__(self, num_classes, pretrained_net):
    super().__init__()
    self.pretrained_net = pretrained_net
    self.relu = nn.ReLU(inplace=True)
    self.conv5 = conv(512, 256, stride=2, transposed=True)
    self.bn5 = bn(256)
    self.conv6 = conv(256, 128, stride=2, transposed=True)
    self.bn6 = bn(128)
    self.conv7 = conv(128, 64, stride=2, transposed=True)
    self.bn7 = bn(64)
    self.conv8 = conv(64, 64, stride=2, transposed=True)
    self.bn8 = bn(64)
    self.conv9 = conv(64, 32, stride=2, transposed=True)
    self.bn9 = bn(32)
    self.conv10 = conv(32, num_classes, kernel_size=7)
    init.constant(self.conv10.weight, 0)  # Zero init

  def forward(self, x):
    x1, x2, x3, x4, x5 = self.pretrained_net(x)
    x = self.relu(self.bn5(self.conv5(x5)))
    x = self.relu(self.bn6(self.conv6(x + x4)))
    x = self.relu(self.bn7(self.conv7(x + x3)))
    x = self.relu(self.bn8(self.conv8(x + x2)))
    x = self.relu(self.bn9(self.conv9(x + x1)))
    x = self.conv10(x)
    return x
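
For intuition, a standalone check (not part of the commit) of the bilinear init used by conv above: with kernel_size=3 and stride=2, as in every transposed convolution here, the filter is the outer product of (0.5, 1, 0.5), the classic 2x interpolation kernel; conv then divides it by in_planes because a transposed convolution sums contributions over its input channels.

import torch

kernel_size, stride = 3, 2
centre = kernel_size % 2 == 1 and stride - 1 or stride - 0.5  # = 1 for an odd kernel
w = torch.Tensor(kernel_size, kernel_size)
for y in range(kernel_size):
  for x in range(kernel_size):
    w[y, x] = (1 - abs((x - centre) / stride)) * (1 - abs((y - centre) / stride))
print(w)
# 0.25  0.50  0.25
# 0.50  1.00  0.50
# 0.25  0.50  0.25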
