In [1]:
# system libs
import os
import datetime
import argparse
from distutils.version import LooseVersion
# Numerical libs
import numpy as np
import torch
import torch.nn as nn
from scipy.io import loadmat
# Our libs
from dataset import TestDataset
from models import ModelBuilder, SegmentationModule
from utils import colorEncode
from lib.nn import user_scattered_collate, async_copy_to
from lib.utils import as_numpy, mark_volatile
import lib.utils.data as torchdata
import cv2
from broden_dataset_utils.joint_dataset import broden_dataset
from utils import maskrcnn_colorencode, remove_small_mat
import pickle

./broden_dataset/ade20k/ADE20K_2016_07_26/index_ade20k.mat
break point


In [2]:
# Per-task class counts of the Broden label space; args.nr_classes (next cells)
# starts from a copy of this dict and overrides the 'part' entry.
broden_dataset.nr

{'material': 26, 'object': 336, 'part': 153, 'scene': 365, 'texture': 47}

In [3]:
class args:
    """Inference configuration, standing in for an argparse namespace.

    The class object itself (not an instance) is passed to `main`/`test`,
    which read settings as `args.<attr>`.
    """
    # Model architecture
    arch_encoder = 'resnet50'
    arch_decoder = 'upernet'
    fc_dim = 2048

    # Runtime
    gpu_id = 0
    batch_size = 1

    # Test-time preprocessing, consumed by TestDataset. imgSize presumably
    # lists one test scale per entry (TODO confirm against TestDataset).
    imgSize = [300, 400, 500, 600]
    imgMaxSize = 1000
    padding_constant = 8
    segm_downsampling_rate = 8
    # Passed to TestDataset as max_sample; -1 presumably means "no limit".
    num_val = -1
    num_class = 150

    # Checkpoints. Weight paths are derived from model_root/model_path/suffix
    # so the three cannot silently drift out of sync with each other.
    model_root = '/media/emrys/Samsung_T5/research/unifiedparsing'
    model_path = 'upp-resnet50-upernet'
    suffix = '_epoch_40.pth'
    weights_encoder = os.path.join(model_root, model_path, 'encoder' + suffix)
    weights_decoder = os.path.join(model_root, model_path, 'decoder' + suffix)

    # Per-task class counts. 'part' is overridden with the total number of
    # part classes summed over every object that has parts.
    nr_classes = broden_dataset.nr.copy()
    nr_classes['part'] = sum(len(parts) for parts in broden_dataset.object_part.values())

In [4]:
def save_obj(obj, name, out_dir='/media/emrys/Samsung_T5/research/Data/test/'):
    """Pickle `obj` to `<out_dir>/<name>.pkl` using the highest protocol.

    Args:
        obj: Any picklable object.
        name: File stem; '.pkl' is appended.
        out_dir: Destination directory. Defaults to the previously hard-coded
            location and is created if it does not exist yet.

    Returns:
        The path of the written file.
    """
    # Create the directory up front so a fresh machine/drive doesn't fail on open().
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, name + '.pkl')
    with open(path, 'wb') as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
    return path

In [5]:
# Images to run inference on. os.listdir returns entries in arbitrary order,
# so the list is sorted to make processing order (and log output)
# deterministic. Hidden entries (e.g. .DS_Store, .ipynb_checkpoints) and
# non-image files are skipped, since TestDataset presumably cannot decode them.
INFER_DIR = '/media/emrys/Samsung_T5/research/unifiedparsing/infer_data/'
IMAGE_EXTS = ('.jpg', '.jpeg', '.png', '.bmp')
files = sorted(
    fname for fname in os.listdir(INFER_DIR)
    if not fname.startswith('.') and fname.lower().endswith(IMAGE_EXTS)
)

In [6]:
# Display the image names that main() will iterate over.
files

['ADE_val_00000001.jpg', 'n01443537_22563.jpg']

In [9]:
def _zero_predictions(args, seg_size):
    """Allocate zeroed CPU accumulators for every output head.

    Key order (object, material, part, scene) matches the saved dict layout.
    """
    pred_ms = {k: torch.zeros(1, args.nr_classes[k], *seg_size)
               for k in ('object', 'material')}
    # One accumulator per object that has parts, sized by that object's part count.
    pred_ms['part'] = [
        torch.zeros(1, len(broden_dataset.object_part[object_label]), *seg_size)
        for object_label in broden_dataset.object_with_part
    ]
    pred_ms['scene'] = torch.zeros(1, args.nr_classes['scene'])
    return pred_ms


def _finalize_predictions(pred_ms):
    """Reduce averaged scores to label maps; 'scene' keeps its score vector.

    'object', 'material' and each 'part' entry become the argmax over dim 1
    with the leading batch dim squeezed away.
    """
    out = {}
    for k in ('object', 'material'):
        _, labels = torch.max(pred_ms[k], dim=1)
        out[k] = labels.squeeze(0)
    out['part'] = [torch.max(part, dim=1)[1].squeeze(0) for part in pred_ms['part']]
    out['scene'] = pred_ms['scene'].squeeze(0)
    return out


def test(segmentation_module, loader, args, name):
    """Run multi-scale inference and pickle the per-sample predictions.

    For each sample, the model is run once per resized image in
    data['img_data']. The outputs are averaged over those scales and reduced
    to label maps (see _finalize_predictions). The result is converted with
    as_numpy and written with save_obj.

    Args:
        segmentation_module: Model whose output dict has 'scene', 'object',
            'material' tensors and a 'part' list indexed like
            broden_dataset.object_with_part.
        loader: Iterable whose items are sequences; item[0] is a dict with
            'img_ori' (its first two dims give the output size) and
            'img_data' (one input per scale).
        args: Config providing gpu_id and nr_classes.
        name: Output file stem. If the loader yields more than one sample,
            samples after the first are saved as '<name>_<i>' rather than
            overwriting the first result.
    """
    segmentation_module.eval()
    for i, batch in enumerate(loader):
        # Presumably user_scattered_collate yields a list of per-sample dicts.
        data = batch[0]
        seg_size = data['img_ori'].shape[0:2]
        scales = data['img_data']

        with torch.no_grad():
            pred_ms = _zero_predictions(args, seg_size)
            n_parts = len(pred_ms['part'])

            for img in scales:
                feed_dict = async_copy_to({"img": img}, args.gpu_id)
                pred = segmentation_module(feed_dict, seg_size=seg_size)
                # Average over the scales actually fed, not len(args.imgSize),
                # so this stays a true mean if the dataset yields a different count.
                for k in ('scene', 'object', 'material'):
                    pred_ms[k] = pred_ms[k] + pred[k].cpu() / len(scales)
                for idx_part in range(n_parts):
                    pred_ms['part'][idx_part] += pred['part'][idx_part].cpu() / len(scales)

            pred_ms = as_numpy(_finalize_predictions(pred_ms))
            save_name = name if i == 0 else '{}_{}'.format(name, i)
            save_obj(pred_ms, save_name)
        print('[{}] iter {}'
              .format(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), i))


def _build_segmentation_module(args):
    """Build encoder and decoder from `args`, wrap them, and move them to the GPU."""
    builder = ModelBuilder()
    net_encoder = builder.build_encoder(
        arch=args.arch_encoder,
        fc_dim=args.fc_dim,
        weights=args.weights_encoder)
    net_decoder = builder.build_decoder(
        arch=args.arch_decoder,
        fc_dim=args.fc_dim,
        nr_classes=args.nr_classes,
        weights=args.weights_decoder,
        use_softmax=True)
    segmentation_module = SegmentationModule(net_encoder, net_decoder)
    segmentation_module.cuda()
    return segmentation_module


def main(args, image_names=None,
         infer_dir='/media/emrys/Samsung_T5/research/unifiedparsing/infer_data/'):
    """Build the model once, then run `test` on each image separately.

    Args:
        args: Config namespace (the `args` class defined in this notebook).
        image_names: File names inside `infer_dir` to process. When omitted,
            falls back to the notebook-level `files` list, which was
            previously the implicit input.
        infer_dir: Directory containing the images.
    """
    if image_names is None:
        image_names = files  # notebook global built by the listing cell
    torch.cuda.set_device(args.gpu_id)
    segmentation_module = _build_segmentation_module(args)

    # One single-image dataset/loader per file, so each gets its own output pickle.
    for image_name in image_names:
        list_test = [{'fpath_img': os.path.join(infer_dir, image_name)}]
        dataset_val = TestDataset(
            list_test, args, max_sample=args.num_val)
        loader_val = torchdata.DataLoader(
            dataset_val,
            batch_size=args.batch_size,
            shuffle=False,
            collate_fn=user_scattered_collate,
            num_workers=0,
            drop_last=True)

        # splitext keeps dotted stems intact ('a.b.jpg' -> 'a.b'); the old
        # split('.')[0] truncated them and could make output names collide.
        test(segmentation_module, loader_val, args,
             os.path.splitext(image_name)[0])

In [10]:
# Run inference over every image in `files`; test() pickles one prediction dict per image via save_obj.
main(args=args)

Loading weights for net_encoder
Loading weights for net_decoder
# samples: 1
Logcat: save_obj
[2018-11-21 00:29:08] iter 0
# samples: 1
Logcat: save_obj
[2018-11-21 00:29:10] iter 0
