In [1]:
from argparse import ArgumentParser
import os
import random
from matplotlib import pyplot as plt
import torch
from torch import optim
from torch import nn
from torch.nn import functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import models
from torchvision.utils import save_image

from data import CityscapesDataset, num_classes, full_to_colour, train_to_full
from model import FeatureResNet, SegResNet

In [2]:
??FeatureResNet

In [3]:
??SegResNet

In [None]:
# Cityscapes image resolution: 2048 x 1024 pixels

In [None]:
# Setup: command-line hyperparameters, RNG seeding, output directory, plotting backend
parser = ArgumentParser(description='Semantic segmentation')
parser.add_argument('--seed', type=int, default=42, help='Random seed')
parser.add_argument('--workers', type=int, default=8, help='Data loader workers')
parser.add_argument('--epochs', type=int, default=100, help='Training epochs')
parser.add_argument('--crop-size', type=int, default=512, help='Training crop size')
parser.add_argument('--lr', type=float, default=5e-5, help='Learning rate')
parser.add_argument('--weight-decay', type=float, default=2e-4, help='Weight decay')
parser.add_argument('--batch-size', type=int, default=4, help='Batch size')
# parse_known_args() rather than parse_args(): inside a Jupyter kernel sys.argv
# holds the kernel's own flags (e.g. `-f <connection file>`), which parse_args()
# rejects with SystemExit. Unrecognised arguments are ignored instead.
args = parser.parse_known_args()[0]
random.seed(args.seed)
torch.manual_seed(args.seed)
# exist_ok avoids the check-then-create race of os.path.exists() + makedirs()
os.makedirs('results', exist_ok=True)
plt.switch_backend('agg')  # Allow plotting when running remotely

In [38]:
# Hyperparameters: plain-variable counterparts of the argparse options
seed = 42  # random seed
workers = 8  # data loader workers
epochs = 2  # training epochs (argparse default: 100)
crop_size = 512  # training crop size
lr = 5e-5  # learning rate
weight_decay = 2e-4  # weight decay
batch_size = 2  # batch size (argparse default: 4)

In [8]:
# Seed Python's `random` as well as torch, as the argparse setup cell does,
# so both generators start from the same known state.
random.seed(seed)
torch.manual_seed(seed)
# Directory that receives the saved images, weights, scores and plots;
# exist_ok avoids the check-then-create race of os.path.exists() + makedirs().
os.makedirs('results', exist_ok=True)

In [9]:
# Training split with the crop and flip options enabled. The crop size comes
# from the `crop_size` hyperparameter (the former `cropsize` name is not defined).
train_dataset = CityscapesDataset(split='train', crop=crop_size, flip=True)

In [10]:
# Validation split; no crop or flip arguments are passed
val_dataset = CityscapesDataset(split='val')

In [11]:
type(train_dataset)

data.CityscapesDataset

In [13]:
??DataLoader

In [14]:
# Data
# Datasets: training split with the crop/flip options enabled, validation split as-is
train_dataset = CityscapesDataset(split='train', crop=crop_size, flip=True)
val_dataset = CityscapesDataset(split='val')
# Loaders: shuffled mini-batches for training, one sample per batch for validation
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
    pin_memory=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=1,
    num_workers=workers,
    pin_memory=True,
)

In [15]:
type(train_loader)

torch.utils.data.dataloader.DataLoader

In [16]:
# Instantiate the feature network imported from the local `model` module
pretrained_net = FeatureResNet()

In [17]:
??pretrained_net

In [18]:
print(pretrained_net)

FeatureResNet (
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
  (relu): ReLU (inplace)
  (maxpool): MaxPool2d (size=(3, 3), stride=(2, 2), padding=(1, 1), dilation=(1, 1))
  (layer1): Sequential (
    (0): BasicBlock (
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
      (relu): ReLU (inplace)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    )
    (1): BasicBlock (
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
      (relu): ReLU (inplace)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): Batc

In [19]:
# Copy the parameters of torchvision's pretrained ResNet-34 into the feature network
pretrained_net.load_state_dict(models.resnet34(pretrained=True).state_dict())

In [26]:
type(pretrained_net)

model.FeatureResNet

In [25]:
type(SegResNet)

type

In [21]:
num_classes

20

In [23]:
# Build the segmentation network from the feature network and move it to the GPU
net = SegResNet(num_classes, pretrained_net).cuda()

In [24]:
type(net)

model.SegResNet

In [29]:
# Binary cross-entropy criterion; train() applies a sigmoid to the network output before calling it
crit = nn.BCELoss().cuda()

In [31]:
type(crit)

torch.nn.modules.loss.BCELoss

In [34]:
# RMSprop over all parameters of `net`, using the `lr` and `weight_decay` hyperparameters
optimiser = optim.RMSprop(net.parameters(), lr=lr, weight_decay=weight_decay)

In [27]:
??nn.BCELoss()

In [None]:
# Training/Testing setup: model, loss, optimiser and score containers

# Instantiate the feature network imported from the local `model` module
pretrained_net = FeatureResNet()

# Load the parameters of torchvision's pretrained ResNet-34 into it
pretrained_net.load_state_dict(models.resnet34(pretrained=True).state_dict())

# Wrap the feature network in the segmentation model and move it to the GPU
net = SegResNet(num_classes, pretrained_net).cuda()

# Binary cross-entropy loss; train() feeds it sigmoid outputs.
# NOTE(review): the original author questioned why BCE is used here — confirm against the target encoding.
crit = nn.BCELoss().cuda()

# Create the optimiser over all parameters of `net`
optimiser = optim.RMSprop(net.parameters(), lr=lr, weight_decay=weight_decay)

# Per-epoch per-class IoU tensors and per-epoch mean IoU values; test() appends to both
scores, mean_scores = [], []

In [35]:
??net.train()

Object `cuda` not found.


In [37]:
def train(e):
    """Run one training pass over ``train_loader``.

    Relies on the notebook globals ``net``, ``train_loader``, ``optimiser``
    and ``crit``.

    Args:
        e: Epoch number; only used to label the printed loss lines.
    """
    # Training mode; only matters for layers such as dropout / batchnorm
    net.train()
    # The third item of each batch is not used during training
    for i, (images, target, _) in enumerate(train_loader):
        # Clear gradients left over from the previous step
        optimiser.zero_grad()
        # Move the batch to the GPU. `async` became a reserved word in
        # Python 3.7 and PyTorch renamed the argument to `non_blocking`;
        # tensors no longer need to be wrapped in Variable.
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # BCELoss expects probabilities, so squash the network output to [0, 1]
        output = torch.sigmoid(net(images))
        loss = crit(output, target)
        # .item() yields the Python float; indexing a 0-dim loss with [0] raises on current PyTorch
        print(e, i, loss.item())
        # Compute gradients
        loss.backward()
        # Update parameters from the gradients
        optimiser.step()

In [40]:
def iou(pred, target, n_classes=None):
    """Calculate per-class intersection over union between two label tensors.

    Args:
        pred: Integer tensor of predicted class indices.
        target: Integer tensor of ground-truth class indices, same shape as ``pred``.
        n_classes: Total number of classes, including the final one that is
            skipped. Defaults to the notebook-level ``num_classes``.

    Returns:
        A list of ``n_classes - 1`` floats, one per evaluated class. An entry
        is NaN when the class occurs in neither ``pred`` nor ``target``.
    """
    if n_classes is None:
        n_classes = num_classes
    ious = []
    # The last class index is left out of the evaluation
    for cls in range(n_classes - 1):
        pred_inds = pred == cls
        target_inds = target == cls
        # Cast to long before summing so the pixel counts cannot overflow;
        # int() extracts the Python scalar from the 0-dim result.
        intersection = int((pred_inds & target_inds).long().sum())
        union = int(pred_inds.long().sum()) + int(target_inds.long().sum()) - intersection
        if union == 0:
            # Class absent from both prediction and target: exclude from evaluation
            ious.append(float('nan'))
        else:
            ious.append(intersection / union)
    return ious

In [41]:
def test(e):
    """Evaluate ``net`` on ``val_loader`` and record the results for epoch ``e``.

    For every validation batch the per-class IoU is computed with ``iou``;
    every 25th batch also has its first prediction written to
    ``results/<e>_<i>.png`` as a colour image. Afterwards the per-class IoUs
    are averaged over batches (ignoring NaN entries), appended to the global
    ``scores`` / ``mean_scores`` lists, and the weights, the scores and a
    mean-IoU plot are saved under ``results/``.

    Relies on the notebook globals ``net``, ``val_loader``, ``scores`` and
    ``mean_scores``.

    Args:
        e: Epoch number, used in the names of the saved image and weight files.
    """
    net.eval()
    total_ious = []
    # no_grad() replaces the removed `volatile=True` Variable flag
    with torch.no_grad():
        for i, (images, _, target) in enumerate(val_loader):
            # `async` is a reserved word since Python 3.7; the argument is now `non_blocking`
            images = images.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)
            # Explicit dim=1: normalise over the class channel
            output = F.log_softmax(net(images), dim=1)
            b, _, h, w = output.size()
            # Index of the highest-scoring class for every pixel -> (b, h, w)
            pred = output.max(1)[1]
            total_ious.append(iou(pred, target))

            # Save images
            if i % 25 == 0:
                pred = pred.cpu()
                pred_remapped = pred.clone()
                # Convert to full labels (masks come from the untouched `pred`)
                for k, v in train_to_full.items():
                    pred_remapped[pred == k] = v
                # Convert to colour image; each channel view is (b, h, w), so the
                # label mask lines up with it without reshaping
                pred_colour = torch.zeros(b, 3, h, w)
                for k, v in full_to_colour.items():
                    mask = pred_remapped == k
                    for channel in range(3):
                        pred_colour[:, channel][mask] = v[channel]
                save_image(pred_colour[0].float().div(255), os.path.join('results', str(e) + '_' + str(i) + '.png'))

    # Calculate average IoU: rows become classes, columns become batches
    total_ious = torch.Tensor(total_ious).transpose(0, 1)
    ious = torch.Tensor(num_classes - 1)
    for i, class_iou in enumerate(total_ious):
        ious[i] = class_iou[~torch.isnan(class_iou)].mean()  # Calculate mean, ignoring NaNs
    # Plain float so it can be printed, stored and plotted directly
    mean_iou = float(ious.mean())
    print(ious, mean_iou)
    scores.append(ious)

    # Save weights and scores
    torch.save(net.state_dict(), os.path.join('results', str(e) + '_net.pth'))
    torch.save(scores, os.path.join('results', 'scores.pth'))

    # Plot scores
    mean_scores.append(mean_iou)
    es = list(range(len(mean_scores)))
    plt.plot(es, mean_scores, 'b-')
    plt.xlabel('Epoch')
    plt.ylabel('Mean IoU')
    plt.savefig(os.path.join('results', 'ious.png'))
    plt.close()

In [42]:
epochs

2

In [None]:
# Alternate one training pass and one validation pass per epoch (epochs are numbered from 1)
for epoch in range(1, epochs + 1):
    train(epoch)
    test(epoch)

1 0 0.6931458711624146
1 1 0.6355560421943665
1 2 0.6427205801010132
1 3 0.6433953046798706
1 4 0.6305932998657227
1 5 0.6067423820495605
1 6 0.5945364236831665
1 7 0.5863555669784546
1 8 0.5912925004959106
1 9 0.5894274711608887
1 10 0.5746929049491882
1 11 0.5672615766525269
1 12 0.5707300901412964
1 13 0.5112751126289368
1 14 0.48332276940345764
1 15 0.5117632150650024
1 16 0.47450026869773865
1 17 0.463645875453949
1 18 0.3759079873561859
1 19 0.39688995480537415
1 20 0.37869444489479065
1 21 0.34906309843063354
1 22 0.3234705626964569
1 23 0.3480839133262634
1 24 0.28830486536026
1 25 0.2619107961654663
1 26 0.2768684923648834
1 27 0.24320001900196075
1 28 0.24300476908683777
1 29 0.23459315299987793
1 30 0.22209849953651428
1 31 0.20149092376232147
1 32 0.23562505841255188
1 33 0.2050512284040451
1 34 0.2157030999660492
1 35 0.20715849101543427
1 36 0.22229750454425812
1 37 0.23299284279346466
1 38 0.1910499930381775
1 39 0.19952347874641418
1 40 0.18842554092407227
1 41 0.201897

1 324 0.16952621936798096
1 325 0.14396785199642181
1 326 0.11784398555755615
1 327 0.12917423248291016
1 328 0.18002407252788544
1 329 0.1367807239294052
1 330 0.12215503305196762
1 331 0.13240770995616913
1 332 0.18047718703746796
1 333 0.16954916715621948
1 334 0.15677234530448914
1 335 0.1712571531534195
1 336 0.15397094190120697
1 337 0.15282964706420898
1 338 0.12067949771881104
1 339 0.14261645078659058
1 340 0.13222673535346985
1 341 0.12911447882652283
1 342 0.14942654967308044
1 343 0.1129477247595787
1 344 0.15369239449501038
1 345 0.1069754958152771
1 346 0.19257931411266327
1 347 0.11606863886117935
1 348 0.10825903713703156
1 349 0.16675244271755219
1 350 0.12300588935613632
1 351 0.16172607243061066
1 352 0.1606787145137787
1 353 0.1218864917755127
1 354 0.11406876891851425
1 355 0.1262023150920868
1 356 0.12387244403362274
1 357 0.17317430675029755
1 358 0.1475595384836197
1 359 0.13853588700294495
1 360 0.1437016725540161
1 361 0.20906205475330353
1 362 0.1647128313779

1 643 0.1600598394870758
1 644 0.12066410481929779
1 645 0.18574878573417664
1 646 0.1279158592224121
1 647 0.1750504970550537
1 648 0.14038851857185364
1 649 0.15080709755420685
1 650 0.16173292696475983
1 651 0.09766177088022232
1 652 0.16784359514713287
1 653 0.13530613481998444
1 654 0.14923451840877533
1 655 0.14028196036815643
1 656 0.13676327466964722
1 657 0.15362763404846191
1 658 0.09165014326572418
1 659 0.12035313993692398
1 660 0.14953011274337769
1 661 0.10439883172512054
1 662 0.14715030789375305
1 663 0.11167421191930771
1 664 0.14883002638816833
1 665 0.13839280605316162
1 666 0.14774355292320251
1 667 0.12894800305366516
1 668 0.11570147424936295
1 669 0.14993873238563538
1 670 0.11998595297336578
1 671 0.12365175783634186
1 672 0.1537283957004547
1 673 0.148948535323143
1 674 0.1502455770969391
1 675 0.14124919474124908
1 676 0.1090245246887207
1 677 0.14162799715995789
1 678 0.11744359880685806
1 679 0.177715465426445
1 680 0.14492806792259216
1 681 0.19054862856864

1 962 0.13902008533477783
1 963 0.16731010377407074
1 964 0.17439484596252441
1 965 0.1561526358127594
1 966 0.09468428045511246
1 967 0.1446932703256607
1 968 0.13221439719200134
1 969 0.14815539121627808
1 970 0.17783118784427643
1 971 0.09559211879968643
1 972 0.1744704246520996
1 973 0.12635818123817444
1 974 0.10200254619121552
1 975 0.1254597008228302
1 976 0.11359286308288574
1 977 0.1358216106891632
1 978 0.14164802432060242
1 979 0.181075319647789
1 980 0.195900097489357
1 981 0.14487335085868835
1 982 0.09517328441143036
1 983 0.13585111498832703
1 984 0.17016127705574036
1 985 0.17779752612113953
1 986 0.15345928072929382
1 987 0.13365010917186737
1 988 0.1279253363609314
1 989 0.2053709328174591
1 990 0.12474668025970459
1 991 0.15621192753314972
1 992 0.13031716644763947
1 993 0.1544943004846573
1 994 0.13229431211948395
1 995 0.17253969609737396
1 996 0.1359737664461136
1 997 0.09505583345890045
1 998 0.16339127719402313
1 999 0.148445725440979
1 1000 0.11404011398553848


1 1270 0.1512284129858017
1 1271 0.14318907260894775
1 1272 0.1095210462808609
1 1273 0.1776197850704193
1 1274 0.1407162845134735
1 1275 0.1674683839082718
1 1276 0.13640102744102478
1 1277 0.15084335207939148
1 1278 0.10057435184717178
1 1279 0.11092312633991241
1 1280 0.20158538222312927
1 1281 0.10035575926303864
1 1282 0.11627276986837387
1 1283 0.0949200838804245
1 1284 0.20306377112865448
1 1285 0.1396486461162567
1 1286 0.18580171465873718
1 1287 0.1208876371383667
1 1288 0.12712427973747253
1 1289 0.1471814513206482
1 1290 0.13176774978637695
1 1291 0.1371045708656311
1 1292 0.18106260895729065
1 1293 0.13761486113071442
1 1294 0.1198289543390274
1 1295 0.13549558818340302
1 1296 0.17132753133773804
1 1297 0.09824708849191666
1 1298 0.1514371633529663
1 1299 0.13074582815170288
1 1300 0.08754850924015045
1 1301 0.16991600394248962
1 1302 0.14063255488872528
1 1303 0.14310656487941742
1 1304 0.11840786039829254
1 1305 0.08974780142307281
1 1306 0.1580015867948532
1 1307 0.14879