In [2]:
%matplotlib inline
import os
import collections
import torch
import torchvision
import numpy as np
import scipy.misc as m
import scipy.io as io
import matplotlib.pyplot as plt

from NYUDv2Loader import *

# Notebook-wide configuration.

# Restrict CUDA to a single physical GPU. This has to be set before the first
# CUDA call in the process to have any effect.
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

# Root directory of the NYUDv2 dataset. The NYUDV2_DATA_PATH environment
# variable overrides the original machine-specific default so the notebook can
# run on another machine without editing this cell. os.path.join(..., '')
# guarantees the trailing separator that the default value already carries.
data_path = os.path.join(
    os.environ.get('NYUDV2_DATA_PATH', '/home/dongwonshin/Desktop/Datasets/NYUDv2/'),
    '')

# Command-line style arguments, parsed later by the argparse parser.
arg_string = '--arch fcn8s --batch_size 3'

# Argument setting

In [None]:
import sys, os
import torch
import visdom
import argparse
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

from torch.autograd import Variable
from torch.utils import data
from tqdm import tqdm

from ptsemseg.models import get_model
from ptsemseg.loader import get_loader, get_data_path
from ptsemseg.metrics import runningScore
from ptsemseg.loss import *
from ptsemseg.augmentations import *


# Hyper-parameter parser. Every valued option uses nargs='?' with an explicit
# default; the two boolean settings are exposed as on/off flag pairs sharing a
# destination, with the effective default fixed through set_defaults().
parser = argparse.ArgumentParser(description='Hyperparams')

# --- network / input geometry ---
parser.add_argument(
    '--arch',
    nargs='?',
    type=str,
    default='fcn8s',
    help="Architecture to use ['fcn8s, unet, segnet etc']",
)
parser.add_argument(
    '--img_rows',
    nargs='?',
    type=int,
    default=256,
    help='Height of the input image',
)
parser.add_argument(
    '--img_cols',
    nargs='?',
    type=int,
    default=256,
    help='Width of the input image',
)

# --- input normalisation toggle (on unless --no-img_norm is given) ---
parser.add_argument(
    '--img_norm',
    dest='img_norm',
    action='store_true',
    help='Enable input image scales normalization [0, 1] | True by default',
)
parser.add_argument(
    '--no-img_norm',
    dest='img_norm',
    action='store_false',
    help='Disable input image scales normalization [0, 1] | True by default',
)
parser.set_defaults(img_norm=True)

# --- optimisation schedule ---
parser.add_argument(
    '--n_epoch',
    nargs='?',
    type=int,
    default=100,
    help='# of the epochs',
)
parser.add_argument(
    '--batch_size',
    nargs='?',
    type=int,
    default=1,
    help='Batch Size',
)
parser.add_argument(
    '--l_rate',
    nargs='?',
    type=float,
    default=1e-5,
    help='Learning Rate',
)
parser.add_argument(
    '--feature_scale',
    nargs='?',
    type=int,
    default=1,
    help='Divider for # of features to use',
)
parser.add_argument(
    '--resume',
    nargs='?',
    type=str,
    default=None,
    help='Path to previous saved model to restart from',
)

# --- visdom toggle (off unless --visdom is given) ---
parser.add_argument(
    '--visdom',
    dest='visdom',
    action='store_true',
    help='Enable visualization(s) on visdom | False by default',
)
parser.add_argument(
    '--no-visdom',
    dest='visdom',
    action='store_false',
    help='Disable visualization(s) on visdom | False by default',
)
parser.set_defaults(visdom=False)

# Dataset preparation & model construction

In [None]:
# Build everything the training cell needs: parsed arguments, data loaders,
# metric accumulator, optional visdom window, model, optimizer and loss.

# split() with no separator ignores repeated/leading/trailing whitespace and
# turns an empty string into an empty list. split(' ') would instead hand
# argparse empty-string tokens, which it rejects as unrecognized arguments.
args = parser.parse_args(arg_string.split())

# Setup Augmentations
# NOTE(review): data_aug is constructed here but is not passed to either
# loader in this cell, so nothing visible here applies it. Confirm whether
# NYUDv2Loader accepts an augmentation argument.
data_aug= Compose([RandomRotate(10), RandomHorizontallyFlip()])

# Setup Dataloader
t_loader = NYUDv2Loader(data_path, is_transform=True)
v_loader = NYUDv2Loader(data_path, is_transform=True, split='val')

n_classes = t_loader.n_classes
# Only the training loader shuffles; validation keeps the dataset order.
trainloader = data.DataLoader(t_loader, batch_size=args.batch_size, num_workers=16, shuffle=True)
valloader = data.DataLoader(v_loader, batch_size=args.batch_size, num_workers=16)

# Setup Metrics
running_metrics = runningScore(n_classes)

# Setup visdom for visualization
if args.visdom:
    vis = visdom.Visdom()

    # Seed the loss plot with a single (0, 0) point; the training loop appends
    # to this window.
    loss_window = vis.line(X=torch.zeros((1,)).cpu(),
                       Y=torch.zeros((1)).cpu(),
                       opts=dict(xlabel='minibatches',
                                 ylabel='Loss',
                                 title='Training Loss',
                                 legend=['Loss']))

# Setup Model
model = get_model(args.arch, n_classes)

# Wrap for multi-GPU execution over every visible device, then move to GPU.
# The wrapped network is reachable as model.module from here on.
model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
model.cuda()

# Check if model has custom optimizer / loss
if hasattr(model.module, 'optimizer'):
    optimizer = model.module.optimizer
else:
    optimizer = torch.optim.SGD(model.parameters(), lr=args.l_rate, momentum=0.99, weight_decay=5e-4)

if hasattr(model.module, 'loss'):
    print('Using custom loss')
    loss_fn = model.module.loss
else:
    loss_fn = cross_entropy2d

# Optionally restore model and optimizer state from a previous run.
if args.resume is not None:
    if os.path.isfile(args.resume):
        print("Loading model and optimizer from checkpoint '{}'".format(args.resume))
        # SECURITY: torch.load unpickles the file, which can execute arbitrary
        # code. Only resume from checkpoints you trust.
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        # The stored epoch is only reported here; this cell does not use it to
        # change where training starts.
        print("Loaded checkpoint '{}' (epoch {})"
              .format(args.resume, checkpoint['epoch']))
    else:
        print("No checkpoint found at '{}'".format(args.resume))


# Training

In [None]:
# Train for args.n_epoch epochs. After every epoch the model is evaluated on
# the validation loader, the scores are printed, and the model/optimizer state
# is saved whenever the mean IoU is at least as good as the best seen so far.
#
# NOTE(review): this cell deliberately keeps the pre-0.4 PyTorch API used by
# the rest of the notebook (Variable, volatile=True, loss.data[0]). On
# PyTorch >= 0.4 the validation pass should run under torch.no_grad() and the
# scalar loss should be read with loss.item(); see the PyTorch 0.4.0 migration
# guide before upgrading.
best_iou = -100.0
batches_per_epoch = len(trainloader)
for epoch in range(args.n_epoch):
    model.train()
    for i, (color_imgs, depth_imgs, label_imgs) in enumerate(trainloader):
        images = Variable(color_imgs.cuda())
        # The depth batch is moved to the GPU but not fed to the model below.
        depth_images = Variable(depth_imgs.cuda())
        labels = Variable(label_imgs.cuda())

        optimizer.zero_grad()
        outputs = model(images)

        loss = loss_fn(input=outputs, target=labels)

        loss.backward()
        optimizer.step()

        # Read the scalar loss once; it is used for both plotting and logging.
        loss_value = loss.data[0]

        if args.visdom:
            # Plot against a global minibatch counter. Using the per-epoch
            # index i would restart the x-axis at 0 every epoch and draw the
            # epochs on top of each other.
            global_step = epoch * batches_per_epoch + i
            vis.line(
                X=torch.ones((1, 1)).cpu() * global_step,
                Y=torch.Tensor([loss_value]).unsqueeze(0).cpu(),
                win=loss_window,
                update='append')

        if (i+1) % 20 == 0:
            print("Epoch [%d/%d] Loss: %.4f" % (epoch+1, args.n_epoch, loss_value))

    model.eval()
    # enumerate() hides the loader length from tqdm, so pass total explicitly
    # to get a real progress bar instead of a bare iteration counter.
    for i_val, (color_images_val, depth_images_val, label_images_val) in tqdm(enumerate(valloader), total=len(valloader)):
        color_images_val = Variable(color_images_val.cuda(), volatile=True)
        depth_images_val = Variable(depth_images_val.cuda(), volatile=True)
        label_images_val = Variable(label_images_val.cuda(), volatile=True)

        outputs = model(color_images_val)
        # Index of the maximum along dim 1 of the output is the predicted label.
        pred = outputs.data.max(1)[1].cpu().numpy()
        gt = label_images_val.data.cpu().numpy()
        running_metrics.update(gt, pred)

    score, class_iou = running_metrics.get_scores()
    for k, v in score.items():
        print(k, v)
    # Clear the accumulator so the next epoch's scores start fresh.
    running_metrics.reset()

    mean_iou = score['Mean IoU : \t']
    if mean_iou >= best_iou:
        best_iou = mean_iou
        state = {'epoch': epoch+1,
                 'model_state': model.state_dict(),
                 'optimizer_state' : optimizer.state_dict(),}
        torch.save(state, "{}_{}_best_model.pkl".format(args.arch, 'NYUDv2'))



Epoch [1/100] Loss: 2.8223
Epoch [1/100] Loss: 2.7707
Epoch [1/100] Loss: 2.2866
Epoch [1/100] Loss: 2.6855
Epoch [1/100] Loss: 2.2286
Epoch [1/100] Loss: 2.2721
Epoch [1/100] Loss: 2.0192
Epoch [1/100] Loss: 2.0519
Epoch [1/100] Loss: 2.1019
Epoch [1/100] Loss: 1.9484
Epoch [1/100] Loss: 1.7170
Epoch [1/100] Loss: 2.1335
Epoch [1/100] Loss: 2.0073


218it [01:09,  3.12it/s]


Overall Acc: 	 0.34535749554
Mean Acc : 	 0.148755437557
FreqW Acc : 	 0.175323779791
Mean IoU : 	 0.0765410621489
Epoch [2/100] Loss: 2.0023
Epoch [2/100] Loss: 2.2467
Epoch [2/100] Loss: 1.7664
Epoch [2/100] Loss: 2.0660
Epoch [2/100] Loss: 1.8743
Epoch [2/100] Loss: 2.1356
Epoch [2/100] Loss: 1.8393
Epoch [2/100] Loss: 1.8988
Epoch [2/100] Loss: 1.8578
Epoch [2/100] Loss: 1.8351
Epoch [2/100] Loss: 2.0277
Epoch [2/100] Loss: 1.7773
Epoch [2/100] Loss: 1.8238


218it [01:10,  3.09it/s]


Overall Acc: 	 0.38472464123
Mean Acc : 	 0.170219727469
FreqW Acc : 	 0.217634482429
Mean IoU : 	 0.0991177622536
Epoch [3/100] Loss: 1.7351
Epoch [3/100] Loss: 1.7483
Epoch [3/100] Loss: 2.0667
Epoch [3/100] Loss: 1.6060
Epoch [3/100] Loss: 1.7909
Epoch [3/100] Loss: 1.6755
Epoch [3/100] Loss: 1.7433
Epoch [3/100] Loss: 1.7195
Epoch [3/100] Loss: 1.7841
Epoch [3/100] Loss: 1.8623
Epoch [3/100] Loss: 1.8680
Epoch [3/100] Loss: 1.7580
Epoch [3/100] Loss: 1.7716


218it [01:10,  3.07it/s]


Overall Acc: 	 0.41972901635
Mean Acc : 	 0.205904484406
FreqW Acc : 	 0.242796405842
Mean IoU : 	 0.121682437797
Epoch [4/100] Loss: 1.6742
Epoch [4/100] Loss: 1.5536
Epoch [4/100] Loss: 1.4812
Epoch [4/100] Loss: 1.5132
Epoch [4/100] Loss: 1.6414
Epoch [4/100] Loss: 1.6000
Epoch [4/100] Loss: 2.0642
Epoch [4/100] Loss: 1.6660
Epoch [4/100] Loss: 1.5205
Epoch [4/100] Loss: 1.4977
Epoch [4/100] Loss: 1.7578
Epoch [4/100] Loss: 1.6732
Epoch [4/100] Loss: 1.9923


218it [01:11,  3.05it/s]


Overall Acc: 	 0.443160249825
Mean Acc : 	 0.235922690259
FreqW Acc : 	 0.267442665654
Mean IoU : 	 0.145968057772
Epoch [5/100] Loss: 1.6311
Epoch [5/100] Loss: 1.7194
Epoch [5/100] Loss: 1.4746
Epoch [5/100] Loss: 1.5417
Epoch [5/100] Loss: 1.2758
Epoch [5/100] Loss: 1.3241
Epoch [5/100] Loss: 1.3959
Epoch [5/100] Loss: 1.3835
Epoch [5/100] Loss: 1.5792
Epoch [5/100] Loss: 1.5264
Epoch [5/100] Loss: 1.4831
Epoch [5/100] Loss: 1.6235
Epoch [5/100] Loss: 1.7546


218it [01:11,  3.04it/s]


Overall Acc: 	 0.453301094825
Mean Acc : 	 0.247615944633
FreqW Acc : 	 0.272060357504
Mean IoU : 	 0.151805814289
Epoch [6/100] Loss: 1.8825
Epoch [6/100] Loss: 1.6359
Epoch [6/100] Loss: 1.9251
Epoch [6/100] Loss: 1.4764
Epoch [6/100] Loss: 1.4282
Epoch [6/100] Loss: 1.7852
Epoch [6/100] Loss: 1.3910
Epoch [6/100] Loss: 1.6879
Epoch [6/100] Loss: 1.5521
Epoch [6/100] Loss: 1.3879
Epoch [6/100] Loss: 1.7041
Epoch [6/100] Loss: 1.6501
Epoch [6/100] Loss: 1.4465


218it [01:11,  3.05it/s]


Overall Acc: 	 0.468737128488
Mean Acc : 	 0.282656711922
FreqW Acc : 	 0.292841938961
Mean IoU : 	 0.181479617416
Epoch [7/100] Loss: 1.4750
Epoch [7/100] Loss: 1.5449
Epoch [7/100] Loss: 1.3802
Epoch [7/100] Loss: 1.5064
Epoch [7/100] Loss: 1.6115
Epoch [7/100] Loss: 1.7490
Epoch [7/100] Loss: 1.6469
Epoch [7/100] Loss: 1.6533
Epoch [7/100] Loss: 1.7591
Epoch [7/100] Loss: 1.5488
Epoch [7/100] Loss: 1.4800
Epoch [7/100] Loss: 1.5863
Epoch [7/100] Loss: 1.5623


218it [01:11,  3.04it/s]


Overall Acc: 	 0.485094301494
Mean Acc : 	 0.302997523887
FreqW Acc : 	 0.309683145282
Mean IoU : 	 0.20004357992
Epoch [8/100] Loss: 1.6307
Epoch [8/100] Loss: 1.2992
Epoch [8/100] Loss: 1.4320
Epoch [8/100] Loss: 1.6055
Epoch [8/100] Loss: 1.4078
Epoch [8/100] Loss: 1.5726
Epoch [8/100] Loss: 1.6345
Epoch [8/100] Loss: 1.2700
Epoch [8/100] Loss: 1.4608
Epoch [8/100] Loss: 1.3781
Epoch [8/100] Loss: 1.5806
Epoch [8/100] Loss: 1.5398
Epoch [8/100] Loss: 1.6307


218it [01:11,  3.04it/s]


Overall Acc: 	 0.499608220247
Mean Acc : 	 0.321107518702
FreqW Acc : 	 0.32076341504
Mean IoU : 	 0.21526069011
Epoch [9/100] Loss: 1.2524
Epoch [9/100] Loss: 1.3671
Epoch [9/100] Loss: 1.2102
Epoch [9/100] Loss: 1.3590
Epoch [9/100] Loss: 1.6888
Epoch [9/100] Loss: 2.0692
Epoch [9/100] Loss: 1.7319
Epoch [9/100] Loss: 1.4999
Epoch [9/100] Loss: 1.0859
Epoch [9/100] Loss: 1.5251
Epoch [9/100] Loss: 1.3614
Epoch [9/100] Loss: 1.1834
Epoch [9/100] Loss: 1.2080


218it [01:11,  3.07it/s]


Overall Acc: 	 0.504011626171
Mean Acc : 	 0.325045150277
FreqW Acc : 	 0.330761981546
Mean IoU : 	 0.221418087976
Epoch [10/100] Loss: 1.4069
Epoch [10/100] Loss: 1.5504
Epoch [10/100] Loss: 1.4821
Epoch [10/100] Loss: 1.0400
Epoch [10/100] Loss: 1.4464
Epoch [10/100] Loss: 1.4482
Epoch [10/100] Loss: 1.8135
Epoch [10/100] Loss: 1.0456
Epoch [10/100] Loss: 1.2626
Epoch [10/100] Loss: 1.3130
Epoch [10/100] Loss: 1.6059
Epoch [10/100] Loss: 1.4089
Epoch [10/100] Loss: 1.1815


218it [01:11,  3.05it/s]


Overall Acc: 	 0.515852725217
Mean Acc : 	 0.349110994
FreqW Acc : 	 0.3412173707
Mean IoU : 	 0.23779808084
Epoch [11/100] Loss: 1.3130
Epoch [11/100] Loss: 1.3871
Epoch [11/100] Loss: 1.6470
Epoch [11/100] Loss: 1.3096
Epoch [11/100] Loss: 1.1949
Epoch [11/100] Loss: 1.6572
Epoch [11/100] Loss: 1.6382
Epoch [11/100] Loss: 1.3772
Epoch [11/100] Loss: 1.2069
Epoch [11/100] Loss: 1.3397
Epoch [11/100] Loss: 1.3042
Epoch [11/100] Loss: 1.4985
Epoch [11/100] Loss: 1.2690


218it [01:11,  3.04it/s]


Overall Acc: 	 0.521798198984
Mean Acc : 	 0.361632514882
FreqW Acc : 	 0.349673274913
Mean IoU : 	 0.249651098056
Epoch [12/100] Loss: 1.4482
Epoch [12/100] Loss: 1.3262
Epoch [12/100] Loss: 1.1898
Epoch [12/100] Loss: 1.6979
Epoch [12/100] Loss: 1.0110
Epoch [12/100] Loss: 1.6499
Epoch [12/100] Loss: 1.3692
Epoch [12/100] Loss: 1.5638
Epoch [12/100] Loss: 1.3350
Epoch [12/100] Loss: 1.6023
Epoch [12/100] Loss: 1.9383
Epoch [12/100] Loss: 1.2984
Epoch [12/100] Loss: 1.3308


218it [01:11,  3.04it/s]

Overall Acc: 	 0.523965645109
Mean Acc : 	 0.352666306838
FreqW Acc : 	 0.347256615329
Mean IoU : 	 0.244445404027





Epoch [13/100] Loss: 1.0632
Epoch [13/100] Loss: 1.1770
Epoch [13/100] Loss: 1.3590
Epoch [13/100] Loss: 1.3313
Epoch [13/100] Loss: 1.4503
Epoch [13/100] Loss: 1.3562
Epoch [13/100] Loss: 1.4797
Epoch [13/100] Loss: 1.2401
Epoch [13/100] Loss: 1.8031
Epoch [13/100] Loss: 1.2321
Epoch [13/100] Loss: 1.1918
Epoch [13/100] Loss: 1.2494
Epoch [13/100] Loss: 1.3682


218it [01:11,  3.04it/s]


Overall Acc: 	 0.530057727685
Mean Acc : 	 0.372655877053
FreqW Acc : 	 0.351217836679
Mean IoU : 	 0.25357418145
Epoch [14/100] Loss: 1.2805
Epoch [14/100] Loss: 1.3278
Epoch [14/100] Loss: 1.2821
Epoch [14/100] Loss: 1.4612
Epoch [14/100] Loss: 1.3846
Epoch [14/100] Loss: 1.5522
Epoch [14/100] Loss: 1.3711
Epoch [14/100] Loss: 1.2950
Epoch [14/100] Loss: 1.6120
Epoch [14/100] Loss: 0.9766
Epoch [14/100] Loss: 1.2241
Epoch [14/100] Loss: 1.1125
Epoch [14/100] Loss: 1.2198


218it [01:11,  3.05it/s]


Overall Acc: 	 0.534332010345
Mean Acc : 	 0.378587542262
FreqW Acc : 	 0.359290804424
Mean IoU : 	 0.261446684783
Epoch [15/100] Loss: 1.6128
Epoch [15/100] Loss: 1.5607
Epoch [15/100] Loss: 1.3420
Epoch [15/100] Loss: 1.2568
Epoch [15/100] Loss: 1.3905
Epoch [15/100] Loss: 1.1932
Epoch [15/100] Loss: 1.2905
Epoch [15/100] Loss: 1.2346
Epoch [15/100] Loss: 1.3233
Epoch [15/100] Loss: 1.5609
Epoch [15/100] Loss: 0.9851
Epoch [15/100] Loss: 1.0614
Epoch [15/100] Loss: 1.1138


218it [01:11,  3.05it/s]


Overall Acc: 	 0.539048951564
Mean Acc : 	 0.383915719917
FreqW Acc : 	 0.366106188189
Mean IoU : 	 0.267910152756
Epoch [16/100] Loss: 1.1935
Epoch [16/100] Loss: 1.3729
Epoch [16/100] Loss: 1.5367
Epoch [16/100] Loss: 1.4267
Epoch [16/100] Loss: 0.9676
Epoch [16/100] Loss: 0.9279
Epoch [16/100] Loss: 1.4451
Epoch [16/100] Loss: 0.8235
Epoch [16/100] Loss: 1.1482
Epoch [16/100] Loss: 1.4135
Epoch [16/100] Loss: 1.2251
Epoch [16/100] Loss: 1.0053
Epoch [16/100] Loss: 1.5096


218it [02:09,  1.68it/s]


Overall Acc: 	 0.541687054027
Mean Acc : 	 0.404587814572
FreqW Acc : 	 0.367860661474
Mean IoU : 	 0.277000816107
Epoch [17/100] Loss: 1.1856
Epoch [17/100] Loss: 1.1477
Epoch [17/100] Loss: 1.0401
Epoch [17/100] Loss: 1.3285
Epoch [17/100] Loss: 1.6320
Epoch [17/100] Loss: 1.3527
Epoch [17/100] Loss: 1.2015
Epoch [17/100] Loss: 1.5148
Epoch [17/100] Loss: 1.1982
Epoch [17/100] Loss: 1.2609
Epoch [17/100] Loss: 1.3984
Epoch [17/100] Loss: 1.0227
Epoch [17/100] Loss: 1.1045


24it [00:19,  1.24it/s]