In [1]:
%matplotlib inline
import os
import collections
import torch
import torchvision
import numpy as np
import scipy.misc as m
import scipy.io as io
import matplotlib.pyplot as plt

from NYUDv2Loader import *

# Restrict this process to physical GPU 2. This only takes effect if it runs
# before CUDA is first initialised; in this notebook the first CUDA call
# (torch.cuda.device_count / model.cuda) happens in the model-setup cell.
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

# Dataset root. Set the NYUDV2_DATA_PATH environment variable to point at a
# local copy instead of editing this machine-specific default. os.path.join(.., '')
# guarantees a trailing separator, matching the original hard-coded value.
# NOTE(review): presumably NYUDv2Loader appends sub-paths to this string — confirm.
data_path = os.path.join(
    os.environ.get('NYUDV2_DATA_PATH', '/home/dongwonshin/Desktop/Datasets/NYUDv2/'), '')

# Command line handed to parser.parse_args in the setup cell (sys.argv is not used).
arg_string = '--arch RecurrentSegnet --batch_size 3 --n_epoch 50'

# Argument parser definition

In [2]:
import sys, os
import torch
import visdom
import argparse
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

from torch.autograd import Variable
from torch.utils import data
from tqdm import tqdm

from ptsemseg.models import get_model
from ptsemseg.loader import get_loader, get_data_path
from ptsemseg.metrics import runningScore
from ptsemseg.loss import *
from ptsemseg.augmentations import *


def _add_optional_value(p, flag, value_type, default, help_text):
    """Register `flag` on parser `p` as an option taking at most one value.

    With nargs='?' the flag may also be given bare, in which case argparse
    stores the (unset) const, i.e. None.
    """
    p.add_argument(flag, nargs='?', type=value_type, default=default, help=help_text)


def _add_on_off_switch(p, name, default, help_on, help_off):
    """Register paired --<name> / --no-<name> flags writing to the same dest."""
    p.add_argument('--' + name, dest=name, action='store_true', help=help_on)
    p.add_argument('--no-' + name, dest=name, action='store_false', help=help_off)
    p.set_defaults(**{name: default})


# Hyper-parameter parser. Options are registered in the order --help lists them.
parser = argparse.ArgumentParser(description='Hyperparams')

# Network choice and input geometry.
_add_optional_value(parser, '--arch', str, 'fcn8s',
                    'Architecture to use [\'fcn8s, unet, segnet etc\']')
_add_optional_value(parser, '--img_rows', int, 256, 'Height of the input image')
_add_optional_value(parser, '--img_cols', int, 256, 'Width of the input image')

_add_on_off_switch(parser, 'img_norm', True,
                   'Enable input image scales normalization [0, 1] | True by default',
                   'Disable input image scales normalization [0, 1] | True by default')

# Optimisation schedule and checkpoint restart.
_add_optional_value(parser, '--n_epoch', int, 10, '# of the epochs')
_add_optional_value(parser, '--batch_size', int, 1, 'Batch Size')
_add_optional_value(parser, '--l_rate', float, 1e-5, 'Learning Rate')
_add_optional_value(parser, '--feature_scale', int, 1, 'Divider for # of features to use')
_add_optional_value(parser, '--resume', str, None,
                    'Path to previous saved model to restart from')

_add_on_off_switch(parser, 'visdom', False,
                   'Enable visualization(s) on visdom | False by default',
                   'Disable visualization(s) on visdom | False by default')

# Data loaders, model, optimizer/loss, and optional checkpoint resume

In [3]:
# Parse the hyper-parameters from the fixed `arg_string` of the first cell
# (the notebook never reads sys.argv).
args = parser.parse_args(arg_string.split(' '))

# --- Data loaders -----------------------------------------------------------
# Both splits are read from the same dataset root; only `split` differs.
t_loader = NYUDv2Loader(data_path, is_transform=True)
v_loader = NYUDv2Loader(data_path, is_transform=True, split='val')
n_classes = t_loader.n_classes

trainloader = data.DataLoader(t_loader, batch_size=args.batch_size,
                              num_workers=16, shuffle=True)
valloader = data.DataLoader(v_loader, batch_size=args.batch_size,
                            num_workers=16)

# Metric accumulator; the training cell calls update() per validation batch,
# then get_scores() and reset() once per epoch.
running_metrics = runningScore(n_classes)

# --- Optional visdom loss plot ----------------------------------------------
if args.visdom:
    vis = visdom.Visdom()
    loss_plot_opts = dict(xlabel='minibatches', ylabel='Loss',
                          title='Training Loss', legend=['Loss'])
    # NOTE(review): nothing in the training cell appends points to this window.
    loss_window = vis.line(X=torch.zeros((1,)).cpu(),
                           Y=torch.zeros((1)).cpu(),
                           opts=loss_plot_opts)

# --- Model ------------------------------------------------------------------
# Replicate across every visible GPU, then move the parameters onto CUDA.
model = torch.nn.DataParallel(get_model(args.arch, n_classes),
                              device_ids=range(torch.cuda.device_count()))
model.cuda()

# Prefer an optimizer / loss the architecture supplies itself (reached through
# DataParallel's `.module`); otherwise fall back to momentum SGD and
# cross_entropy2d.
optimizer = (model.module.optimizer if hasattr(model.module, 'optimizer')
             else torch.optim.SGD(model.parameters(), lr=args.l_rate,
                                  momentum=0.99, weight_decay=5e-4))

uses_custom_loss = hasattr(model.module, 'loss')
if uses_custom_loss:
    print('Using custom loss')
loss_fn = model.module.loss if uses_custom_loss else cross_entropy2d

# --- Optional resume --------------------------------------------------------
# Restores model and optimizer state only. A missing file is just reported and
# the freshly initialised model is used; the training cell still counts epochs
# from 0 either way.
if args.resume is not None:
    if not os.path.isfile(args.resume):
        print("No checkpoint found at '{}'".format(args.resume))
    else:
        print("Loading model and optimizer from checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        print("Loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))


# Training

In [None]:
# Train for args.n_epoch epochs. After each epoch, run a full validation pass and
# checkpoint whenever validation mean IoU ties or beats the best seen so far.
#
# NOTE(review): like the rest of this notebook this uses the PyTorch 0.3-era API
# (Variable, volatile=True, loss.data[0]). On PyTorch >= 0.4 use plain tensors,
# `with torch.no_grad():` and loss.item() instead.
# NOTE(review): best_iou is not restored when --resume is used, so the first
# epoch after resuming always overwrites the saved "best" checkpoint.

MEAN_IOU_KEY = 'Mean IoU : \t'  # key of the mean-IoU entry returned by get_scores()
CHECKPOINT_DIR = '../model_weights'
CHECKPOINT_PATH = os.path.join(
    CHECKPOINT_DIR, '{}_{}_best_model.pkl'.format(args.arch, 'NYUDv2'))
LOG_EVERY = 100  # training minibatches between loss printouts

# Create the output directory before any training time is spent. Otherwise a
# missing folder only surfaces as an error from torch.save after the first epoch.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

best_iou = -100.0
for epoch in range(args.n_epoch):

    # ---- train ----
    model.train()
    for i, (color_imgs, _depth_imgs, label_imgs) in enumerate(trainloader):
        # Only the color image is fed to the model; depth is loaded but unused.
        color_imgs = Variable(color_imgs.cuda())
        label_imgs = Variable(label_imgs.cuda())

        optimizer.zero_grad()
        outputs = model(color_imgs)
        loss = loss_fn(input=outputs, target=label_imgs)
        loss.backward()
        optimizer.step()

        if (i + 1) % LOG_EVERY == 0:
            # Loss of the current minibatch only, not a running average.
            print("Epoch [%d/%d] Loss: %.4f" % (epoch + 1, args.n_epoch, loss.data[0]))

    # ---- validate ----
    model.eval()
    # Wrapping the loader itself (rather than enumerate(...)) lets tqdm read
    # len(valloader) and display a total / ETA.
    for color_images_val, _depth_images_val, label_images_val in tqdm(valloader):
        color_images_val = Variable(color_images_val.cuda(), volatile=True)
        outputs = model(color_images_val)
        # Arg-max index over dim 1 (presumably the class channel — confirm with the model).
        pred = outputs.data.max(1)[1].cpu().numpy()
        # Labels are only needed on the CPU for scoring, so skip the GPU round trip.
        gt = label_images_val.numpy()
        running_metrics.update(gt, pred)

    score, class_iou = running_metrics.get_scores()
    for k, v in score.items():
        print(k, v)
    running_metrics.reset()

    # ---- checkpoint ----
    if score[MEAN_IOU_KEY] >= best_iou:
        best_iou = score[MEAN_IOU_KEY]
        state = {'epoch': epoch + 1,
                 'model_state': model.state_dict(),
                 'optimizer_state': optimizer.state_dict()}
        torch.save(state, CHECKPOINT_PATH)

Epoch [1/50] Loss: 2.6149
Epoch [1/50] Loss: 2.6544


218it [00:29,  7.42it/s]


Overall Acc: 	 0.156668189746
Mean Acc : 	 0.084685320527
FreqW Acc : 	 0.0870345832956
Mean IoU : 	 0.0408465968271
Epoch [2/50] Loss: 2.5172
Epoch [2/50] Loss: 2.5271


218it [00:29,  7.36it/s]


Overall Acc: 	 0.216477660511
Mean Acc : 	 0.0970628441916
FreqW Acc : 	 0.108996780255
Mean IoU : 	 0.0466214067866
Epoch [3/50] Loss: 2.4958
Epoch [3/50] Loss: 2.5207


218it [00:29,  7.32it/s]


Overall Acc: 	 0.222765936584
Mean Acc : 	 0.103240601082
FreqW Acc : 	 0.12075845367
Mean IoU : 	 0.0515171132422
Epoch [4/50] Loss: 2.3988
Epoch [4/50] Loss: 2.5458


218it [00:29,  7.41it/s]


Overall Acc: 	 0.240141118756
Mean Acc : 	 0.106871269097
FreqW Acc : 	 0.126093210672
Mean IoU : 	 0.0533969770104
Epoch [5/50] Loss: 2.5315
Epoch [5/50] Loss: 2.3755


218it [00:29,  7.39it/s]


Overall Acc: 	 0.255875740635
Mean Acc : 	 0.11038296406
FreqW Acc : 	 0.134783366964
Mean IoU : 	 0.0565322457547
Epoch [6/50] Loss: 2.4305
Epoch [6/50] Loss: 2.3820


218it [00:29,  7.30it/s]


Overall Acc: 	 0.272727222501
Mean Acc : 	 0.116720093746
FreqW Acc : 	 0.150640191606
Mean IoU : 	 0.0625050004202
Epoch [7/50] Loss: 2.3109
Epoch [7/50] Loss: 2.3057


218it [00:29,  7.34it/s]


Overall Acc: 	 0.294803124602
Mean Acc : 	 0.124914397456
FreqW Acc : 	 0.167470720906
Mean IoU : 	 0.0687745847732
Epoch [8/50] Loss: 2.4028
Epoch [8/50] Loss: 2.2912


218it [00:29,  7.39it/s]


Overall Acc: 	 0.311159486294
Mean Acc : 	 0.130466031518
FreqW Acc : 	 0.177392143471
Mean IoU : 	 0.0725482976044
Epoch [9/50] Loss: 2.3859
Epoch [9/50] Loss: 2.3991


218it [00:29,  7.32it/s]


Overall Acc: 	 0.317765782285
Mean Acc : 	 0.132250390673
FreqW Acc : 	 0.179885341888
Mean IoU : 	 0.0734542672947
Epoch [10/50] Loss: 2.2707
Epoch [10/50] Loss: 2.1223


218it [00:29,  7.37it/s]


Overall Acc: 	 0.321658414166
Mean Acc : 	 0.133818033038
FreqW Acc : 	 0.183737602849
Mean IoU : 	 0.0751050272023
Epoch [11/50] Loss: 2.0392
Epoch [11/50] Loss: 2.3097


218it [00:29,  7.37it/s]


Overall Acc: 	 0.326477227478
Mean Acc : 	 0.134772115313
FreqW Acc : 	 0.186729819388
Mean IoU : 	 0.0760577136259
Epoch [12/50] Loss: 2.2926
Epoch [12/50] Loss: 2.3511


218it [00:29,  7.36it/s]


Overall Acc: 	 0.3321646787
Mean Acc : 	 0.135707552712
FreqW Acc : 	 0.186862562117
Mean IoU : 	 0.0764688100716
Epoch [13/50] Loss: 2.3254
Epoch [13/50] Loss: 2.3199


218it [00:29,  7.40it/s]


Overall Acc: 	 0.336551430301
Mean Acc : 	 0.139285299088
FreqW Acc : 	 0.194510809498
Mean IoU : 	 0.0796663851349
Epoch [14/50] Loss: 2.1684
Epoch [14/50] Loss: 2.2823


218it [00:29,  7.28it/s]


Overall Acc: 	 0.342568080642
Mean Acc : 	 0.142637350207
FreqW Acc : 	 0.197314288065
Mean IoU : 	 0.081524338663
Epoch [15/50] Loss: 2.1639
Epoch [15/50] Loss: 2.3122


218it [00:29,  7.28it/s]


Overall Acc: 	 0.349571049153
Mean Acc : 	 0.147296308964
FreqW Acc : 	 0.206719290882
Mean IoU : 	 0.0855897493619
Epoch [16/50] Loss: 2.2830
Epoch [16/50] Loss: 2.2703


218it [00:29,  7.32it/s]


Overall Acc: 	 0.357822743454
Mean Acc : 	 0.151682934197
FreqW Acc : 	 0.206334458541
Mean IoU : 	 0.0868140974854
Epoch [17/50] Loss: 2.1897
Epoch [17/50] Loss: 1.9763


218it [00:29,  7.36it/s]


Overall Acc: 	 0.365943268787
Mean Acc : 	 0.15676580047
FreqW Acc : 	 0.211833583515
Mean IoU : 	 0.0897685946988
Epoch [18/50] Loss: 2.0034
Epoch [18/50] Loss: 2.1121


218it [00:29,  7.28it/s]


Overall Acc: 	 0.372573744903
Mean Acc : 	 0.16100103662
FreqW Acc : 	 0.212265316486
Mean IoU : 	 0.0909579400604
Epoch [19/50] Loss: 1.8738
Epoch [19/50] Loss: 2.2054


218it [00:29,  7.36it/s]


Overall Acc: 	 0.383880054034
Mean Acc : 	 0.16669344854
FreqW Acc : 	 0.219071474816
Mean IoU : 	 0.0944927510152
Epoch [20/50] Loss: 2.1757
Epoch [20/50] Loss: 1.9657


218it [00:29,  7.35it/s]


Overall Acc: 	 0.387490144782
Mean Acc : 	 0.168942902635
FreqW Acc : 	 0.222691321448
Mean IoU : 	 0.0964443410326
Epoch [21/50] Loss: 2.2005
Epoch [21/50] Loss: 1.9877


218it [00:29,  7.34it/s]


Overall Acc: 	 0.394108909117
Mean Acc : 	 0.172052021729
FreqW Acc : 	 0.22309946743
Mean IoU : 	 0.0970372520763
Epoch [22/50] Loss: 2.1816
Epoch [22/50] Loss: 2.0314


218it [00:29,  7.31it/s]


Overall Acc: 	 0.39601263857
Mean Acc : 	 0.173310641048
FreqW Acc : 	 0.22683549146
Mean IoU : 	 0.0981562496683
Epoch [23/50] Loss: 1.9110
Epoch [23/50] Loss: 2.1921


218it [00:29,  7.29it/s]


Overall Acc: 	 0.400059300538
Mean Acc : 	 0.175050198145
FreqW Acc : 	 0.22791763259
Mean IoU : 	 0.099506719699
Epoch [24/50] Loss: 2.1901
Epoch [24/50] Loss: 2.0321


218it [00:29,  7.31it/s]


Overall Acc: 	 0.403579783464
Mean Acc : 	 0.178847915621
FreqW Acc : 	 0.235558741237
Mean IoU : 	 0.102376315001
Epoch [25/50] Loss: 1.9764
Epoch [25/50] Loss: 2.0062


218it [00:29,  7.35it/s]

Overall Acc: 	 0.40788193449
Mean Acc : 	 0.179319547305
FreqW Acc : 	 0.233181687388
Mean IoU : 	 0.101690514413





Epoch [26/50] Loss: 1.8165
Epoch [26/50] Loss: 2.3798


218it [00:29,  7.40it/s]


Overall Acc: 	 0.410613457449
Mean Acc : 	 0.181818242739
FreqW Acc : 	 0.238798237617
Mean IoU : 	 0.104147148549
Epoch [27/50] Loss: 2.1433
Epoch [27/50] Loss: 1.8515


218it [00:29,  7.27it/s]

Overall Acc: 	 0.410929655645
Mean Acc : 	 0.181600419686
FreqW Acc : 	 0.237620539099
Mean IoU : 	 0.10366585425





Epoch [28/50] Loss: 1.7702
Epoch [28/50] Loss: 1.7916


218it [00:30,  7.12it/s]


Overall Acc: 	 0.415438213757
Mean Acc : 	 0.185304801159
FreqW Acc : 	 0.242481519947
Mean IoU : 	 0.106236526557
Epoch [29/50] Loss: 1.9415
Epoch [29/50] Loss: 1.8394


218it [00:29,  7.32it/s]

Overall Acc: 	 0.416164836981
Mean Acc : 	 0.184382714763
FreqW Acc : 	 0.241376106843
Mean IoU : 	 0.105101995706





Epoch [30/50] Loss: 1.6435
Epoch [30/50] Loss: 1.9042


218it [00:29,  7.30it/s]


Overall Acc: 	 0.41818931774
Mean Acc : 	 0.185195790058
FreqW Acc : 	 0.243012366448
Mean IoU : 	 0.106394917543
Epoch [31/50] Loss: 2.1440
Epoch [31/50] Loss: 2.0485


218it [00:29,  7.30it/s]


Overall Acc: 	 0.418273579853
Mean Acc : 	 0.186739637304
FreqW Acc : 	 0.24404056617
Mean IoU : 	 0.106536347191
