In [None]:
# --------------------------------------------------------
# Tensorflow Faster R-CNN
# Licensed under The MIT License [see LICENSE for details]
# Written by Jiasen Lu, Jianwei Yang, based on code from Ross Girshick
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from easydict import EasyDict
import _init_paths
import os
import sys
import numpy as np
import argparse
import pprint
import pdb
import time
import cv2
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.optim as optim
import xml.dom.minidom as minidom
import PIL
from PIL import Image
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
try:
    import cPickle as pickle
    print('import cPickle')
except:
    import pickle
    print('import python pickle')
from roi_data_layer.roidb import combined_roidb
from roi_data_layer.roibatchLoader import roibatchLoader
from model.utils.config import cfg, cfg_from_file, cfg_from_list, get_output_dir
from model.rpn.bbox_transform import clip_boxes
from model.nms.nms_wrapper import nms
from model.rpn.bbox_transform import bbox_transform_inv
from model.utils.net_utils import save_net, load_net, vis_detections
from model.faster_rcnn.vgg16 import vgg16
from model.faster_rcnn.resnet import resnet
#from model.faster_rcnn.faster_rcnn import _fasterRCNN
from model.siamese_net.siameseRCNN import _siameseRCNN
from model.siamese_net.weight_cropping_layer import weight_crop_layer

import pdb

# Python 2/3 compatibility: Python 3 dropped `xrange`, so alias it to the
# (already lazy) built-in `range`; on Python 2 the built-in is used as-is.
if sys.version_info[0] >= 3:
    xrange = range

# Make GPU 0 the current CUDA device for every later `.cuda()` call (hard-coded).
torch.cuda.set_device(0)
def parse_args():
    """
    Build the evaluation settings for this notebook.

    Nothing is parsed from the command line: every option is hard-coded below and
    returned as an ``EasyDict`` so later cells can use attribute access (``args.net``).
    ``set_cfgs`` and ``large_scale`` get explicit defaults because the config cell
    reads them even when no branch assigns them, which used to raise AttributeError.
    """
    return EasyDict(
        dataset='imagenetVID_PLUS',
        net='res101',
        load_dir='models',
        cuda=True,
        vid_size=1,
        class_agnostic=False,
        cfg_file='cfgs/res101_lighthead_siam.yml',
        ckpt='1_8_54269',
        # Overrides for cfg_from_list; None means "no overrides".
        set_cfgs=None,
        # Only consulted when cfg_file is None (selects the *_ls.yml variant).
        large_scale=False,
    )

def bbox_delta_to_pred_boxes(im_info, boxes, bbox_pred):
    """Decode box-regression deltas into clipped predicted boxes.

    Args:
        im_info: image-info tensor; its ``.data`` is passed to ``clip_boxes``.
        boxes: reference boxes the deltas are applied to (``bbox_transform_inv``).
        bbox_pred: raw regression output; only ``.data`` is read.

    Returns:
        ``clip_boxes(bbox_transform_inv(boxes, deltas, 1), im_info.data, 1)``. The boxes are
        in the network-input frame; callers divide by the im_info scale afterwards.
    """
    box_deltas = bbox_pred.data
    if cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED:
        # Undo the training-time target normalisation. new_tensor puts the stats on the
        # same device/dtype as the deltas instead of unconditionally calling .cuda().
        stds = box_deltas.new_tensor(cfg.TRAIN.BBOX_NORMALIZE_STDS)
        means = box_deltas.new_tensor(cfg.TRAIN.BBOX_NORMALIZE_MEANS)
        box_deltas = box_deltas.view(-1, 4) * stds + means
        # One box per RoI for class-agnostic heads, otherwise one box per class.
        boxes_per_roi = 1 if args.class_agnostic else len(imdb.classes)
        box_deltas = box_deltas.view(1, -1, 4 * boxes_per_roi)

    pred_boxes = bbox_transform_inv(boxes, box_deltas, 1)
    return clip_boxes(pred_boxes, im_info.data, 1)

# Module-level cropping layer used by siam_weights_preparation; built once and moved to the GPU.
weight_cropper = weight_crop_layer().cuda()
def siam_weights_preparation(rois_tracking, base_feat):
    """Crop siamese template weights from ``base_feat`` for the tracked RoIs.

    Args:
        rois_tracking: array-like whose first four columns are box coordinates, or None.
        base_feat: feature tensor; the RoIs are converted to its dtype/device.

    Returns:
        ``(template_weights, rois_tensor)`` or ``(None, None)`` when there is nothing to track.
    """
    if rois_tracking is not None:
        rois_var = Variable(base_feat.new_tensor(rois_tracking))
        coords = rois_var[:, :4]
        # Prefix every box with batch index 0 -> [batch_idx, x1, y1, x2, y2].
        batch_col = coords.new_zeros((rois_var.size(0), 1))
        crop_boxes = torch.cat((batch_col, coords), dim=1)
        return weight_cropper(base_feat, crop_boxes), rois_var
    return None, None

def prepare_rois_tracking(im_info, all_boxes, all_boxes_scores, frame_id, class_num, thresh=None):
    """Collect the confident detections of one frame as tracking RoIs for the next frame.

    Args:
        im_info: sequence/tensor whose last element is the image scale factor; the stored
            boxes are multiplied by it (back into the network-input frame).
        all_boxes: ``all_boxes[j][frame_id]`` is an array with rows
            ``[x1, y1, x2, y2, score, ...]`` for class ``j`` (empty when there are none).
        all_boxes_scores: ``all_boxes_scores[j][frame_id]`` holds the full class-score row
            of each box in ``all_boxes[j][frame_id]``; column ``j`` must equal the box score.
        frame_id: global image index of the frame.
        class_num: number of classes including background (index 0 is skipped).
        thresh: minimum score (strict) for a box to be tracked. ``None`` reads
            ``cfg.SIAMESE.THRESH_FOR_TRACKING`` at call time.

    Returns:
        An array ``[x1, y1, x2, y2, class scores...]`` per selected box (scaled coordinates),
        or ``None`` when no box passes ``thresh``.

    Raises:
        AssertionError: if boxes and score rows disagree in length or score.
    """
    if thresh is None:
        # Resolved lazily: a default argument would freeze the value from before the yml
        # config was loaded by cfg_from_file.
        thresh = cfg.SIAMESE.THRESH_FOR_TRACKING
    # float() works for plain numbers, numpy scalars and (CPU or CUDA) 0-d tensors alike.
    scale = float(im_info[-1])

    sel_boxes = []
    for j in range(1, class_num):
        if len(all_boxes[j][frame_id]) == 0:
            continue
        scored_boxes = all_boxes[j][frame_id].copy()  # copy: the caller's table is not modified
        scores = all_boxes_scores[j][frame_id]
        assert len(scored_boxes) == len(scores), 'length of scored_boxes and length of scores should be the equal.'
        # Vectorised sanity check: each box's score must match column j of its score row.
        mismatch = np.flatnonzero(scored_boxes[:, 4] != scores[:, j])
        assert mismatch.size == 0, 'scores not matched, please check your code.%f!=%f' % (
            scored_boxes[mismatch[0], 4], scores[mismatch[0], j])
        scored_boxes[:, :4] = scored_boxes[:, :4] * scale
        inds = np.where(scored_boxes[:, 4] > thresh)[0]
        if len(inds) > 0:
            sel_boxes.append(np.concatenate((scored_boxes[inds, :4], scores[inds, :]), axis=1))

    if sel_boxes:
        return np.concatenate(sel_boxes, axis=0)
    return None

In [None]:
args = parse_args()

print('Called with args:')
print(args)

if torch.cuda.is_available() and not args.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

# (train imdb, evaluation imdb) for every supported dataset.
# TODO (kept from original): imagenetVID_1_vid evaluates on its *train* split.
_DATASET_IMDBS = {
    'imagenetVID': ('imagenetVID_train', 'imagenetVID_val'),
    'imagenetVID_PLUS': ('imagenetVID_PLUS_train', 'imagenetVID_PLUS_val'),
    'imagenetVID_1_vid': ('imagenetVID_1_vid_train', 'imagenetVID_1_vid_train'),
}
if args.dataset not in _DATASET_IMDBS:
    # Fail early instead of hitting an AttributeError on args.set_cfgs / args.imdbval_name.
    raise ValueError('Unsupported dataset: {}'.format(args.dataset))
args.imdb_name, args.imdbval_name = _DATASET_IMDBS[args.dataset]
# All supported datasets share the same anchor / GT-box overrides.
args.set_cfgs = ['ANCHOR_SCALES', '[4, 8, 16, 32]', 'ANCHOR_RATIOS', '[0.5,1,2]', 'MAX_NUM_GT_BOXES', '30']

if args.cfg_file is None:
    # `large_scale` may be absent from args; treat a missing key as False.
    args.cfg_file = "cfgs/{}_ls.yml".format(args.net) if args.get('large_scale', False) \
        else "cfgs/{}.yml".format(args.net)

cfg_from_file(args.cfg_file)
cfg_from_list(args.set_cfgs)
# Seed only after the yml/list overrides so the configured RNG_SEED is the one used.
np.random.seed(cfg.RNG_SEED)

print('Using config:')
pprint.pprint(cfg)

cfg.TRAIN.USE_FLIPPED = False
imdb, roidb, ratio_list, ratio_index = combined_roidb(args.imdbval_name, False)
imdb.competition_mode(on=True)

print('{:d} roidb entries'.format(len(roidb)))

input_dir = os.path.join(args.load_dir, args.net, args.dataset)
if not os.path.exists(input_dir):
    raise Exception('There is no input directory for loading network from ' + input_dir)

# Checkpoint name: <core>_siam[_OHEM]_<ckpt>.pth
load_name_predix = cfg.RESNET.CORE_CHOICE.USE + '_siam'
if cfg.TRAIN.OHEM is True:
    load_name_predix = load_name_predix + '_OHEM'
load_name = os.path.join(input_dir, load_name_predix + '_{}.pth'.format(args.ckpt))

# Initialise the network.
if args.net == 'res101':
    RCNN = _siameseRCNN(imdb.classes, args)
else:
    # Fail loudly instead of dropping into pdb and continuing with RCNN undefined.
    raise ValueError("network is not defined: {}".format(args.net))

print("load checkpoint %s" % (load_name))
# torch.load unpickles the file: only load checkpoints from a trusted source.
checkpoint = torch.load(load_name)
RCNN.load_state_dict(checkpoint['model'])
if 'pooling_mode' in checkpoint:
    cfg.POOLING_MODE = checkpoint['pooling_mode']

print('load model successfully!')

# Tensor holders that the evaluation loop resizes and fills in place for every frame.
im_data = torch.FloatTensor(1)
im_info = torch.FloatTensor(1)
num_boxes = torch.LongTensor(1)
gt_boxes = torch.FloatTensor(1)

if args.cuda:
    im_data = im_data.cuda()
    im_info = im_info.cuda()
    num_boxes = num_boxes.cuda()
    gt_boxes = gt_boxes.cuda()
    cfg.CUDA = True
    RCNN.cuda()

im_data = Variable(im_data)
im_info = Variable(im_info)
num_boxes = Variable(num_boxes)
gt_boxes = Variable(gt_boxes)

start = time.time()
# Cap on detections kept per image, over all classes.
max_per_image = 100
# Per-class score threshold applied before per-class NMS.
thresh = 0.01

save_name = load_name_predix
num_images = len(imdb.image_index)

output_dir = get_output_dir(imdb, save_name)
dataset = roibatchLoader(roidb, ratio_list, ratio_index, 1,
                         imdb.num_classes, training=False, normalize=False)

In [None]:
import scipy.io as sio
# ILSVRC devkit root; meta_vid.mat is assumed to hold the 30 ImageNet-VID synsets.
devkit_path = 'data/imagenet/ILSVRC/devkit'
####VID object info####
synsets_video = sio.loadmat(os.path.join(devkit_path,'data','meta_vid.mat'))
# Index 0 is the background entry; _wnid[k] and _classes[k] stay index-aligned.
_wnid = (0,)
_classes = ('__background__',)
for i in xrange(30):
    # Field [2] of each synset record is stored as the class name, field [1] as the wnid.
    _classes = _classes + (synsets_video['synsets'][0][i][2][0],)
    _wnid = _wnid + (synsets_video['synsets'][0][i][1][0],)

# Lookup tables (31 = 30 synsets + background): wnid -> index, class name -> index.
_wnid_to_ind = dict(zip(_wnid, xrange(31)))
_class_to_ind = dict(zip(_classes, xrange(31)))
#######################
def _annotation_path(data_path):
    """Map a frame path under ``Data/`` to its XML annotation path under ``Annotations/``."""
    # Swap only the file extension, so directories such as 'JPEGImages' are left alone.
    root, _ = os.path.splitext(data_path)
    parts = root.split('/')
    if 'Data' in parts:
        # Rewrite only the last 'Data' directory, keeping e.g. a '/mnt/Data/...' prefix intact.
        idx = len(parts) - 1 - parts[::-1].index('Data')
        parts[idx] = 'Annotations'
        root = '/'.join(parts)
    else:
        # Unusual layout: fall back to the plain substring replacement.
        root = root.replace('Data', 'Annotations')
    return root + '.xml'


def read_annotation(dataPath):
    """Load the ground-truth boxes and class indices for one frame.

    Args:
        dataPath: path of the frame image inside a ``Data`` directory tree.

    Returns:
        ``(boxes, gt_classes)``: an int32 ``(N, 4)`` array of ``[xmin, ymin, xmax, ymax]``
        (float coordinates are truncated) and an int32 ``(N,)`` array of indices
        looked up in ``_wnid_to_ind``.

    Raises:
        AssertionError: if the annotation file does not exist.
        KeyError: if an object's ``name`` is not a known wnid.
    """
    filename = _annotation_path(dataPath)
    assert os.path.exists(filename), '%s' % (filename)

    def get_data_from_tag(node, tag):
        return node.getElementsByTagName(tag)[0].childNodes[0].data

    # parse() reads the file in binary mode and honours the XML encoding declaration.
    data = minidom.parse(filename)

    objs = data.getElementsByTagName('object')
    boxes = np.zeros((len(objs), 4), dtype=np.int32)
    gt_classes = np.zeros(len(objs), dtype=np.int32)
    for ix, obj in enumerate(objs):
        boxes[ix, :] = [float(get_data_from_tag(obj, tag)) for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
        gt_classes[ix] = _wnid_to_ind[str(get_data_from_tag(obj, "name")).lower().strip()]
    return boxes, gt_classes

In [None]:
def _draw_box_set(ax, boxes_per_class, img_id, classes, threshold, show_class,
                  color, box_alpha, label_alpha):
    """Draw each box in ``boxes_per_class[j][img_id]`` whose last column exceeds ``threshold``."""
    for j in range(len(classes)):
        cls_dets = boxes_per_class[j][img_id]
        # Entries that were never filled are plain lists (e.g. background) and are skipped.
        if not isinstance(cls_dets, np.ndarray):
            continue
        for bbox in cls_dets:
            if not bbox[-1] > threshold:
                continue
            ax.add_patch(
                plt.Rectangle((bbox[0], bbox[1]),
                              bbox[2] - bbox[0],
                              bbox[3] - bbox[1],
                              fill=False, edgecolor=color,
                              linewidth=2.0, alpha=box_alpha))
            if show_class:
                ax.text(bbox[0], bbox[1] - 2,
                        '%s %f' % (classes[j], bbox[-1]),
                        fontsize=10,
                        bbox=dict(facecolor=color, alpha=label_alpha, pad=0, edgecolor='none'),
                        color='white')


def showBoxResult(img_id, all_boxes, imdb, show_class=True, threshold=0.3, dpi=200,
                  all_boxes_tracking=None, all_boxes_detection=None):
    """Render ground-truth and predicted boxes for one frame onto a new figure.

    Colours: green = ground truth, yellow = ``all_boxes``, magenta = ``all_boxes_detection``,
    white = ``all_boxes_tracking``. Layers are drawn in that order.

    Args:
        img_id: index into ``imdb.image_index``.
        all_boxes: ``all_boxes[j][img_id]`` is an array of ``[x1, y1, x2, y2, ..., score]``
            rows for class ``j``; non-array entries are skipped.
        imdb: dataset object providing ``image_index`` and ``classes``.
        show_class: also draw a "<class> <score>" label above each box.
        threshold: predicted boxes need a score strictly above this to be drawn.
        dpi: pixels per inch used to size the figure to the image.
        all_boxes_tracking, all_boxes_detection: optional extra tables with the same layout.

    Returns:
        The matplotlib Figure; the caller is responsible for closing it.
    """
    imIdx = imdb.image_index[img_id] + '.JPEG'
    assert os.path.exists(imIdx), imIdx + ' does not exist.'
    gt_boxes, gt_classes = read_annotation(imIdx)
    # np.array() forces the pixel load, so the file handle can be released right away.
    with PIL.Image.open(imIdx) as img:
        im2show = np.array(img)

    fig = plt.figure(frameon=False)
    fig.set_size_inches(im2show.shape[1] / dpi, im2show.shape[0] / dpi)
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.axis('off')
    fig.add_axes(ax)
    ax.imshow(im2show)

    # Ground truth first, so predictions are drawn on top of it.
    for box, cls in zip(gt_boxes, gt_classes):
        ax.add_patch(
            plt.Rectangle((box[0], box[1]),
                          box[2] - box[0],
                          box[3] - box[1],
                          fill=False, edgecolor='g',
                          linewidth=2.0, alpha=0.8))
        if show_class:
            ax.text(box[0], box[1] - 2,
                    '%s %f' % (imdb.classes[cls], 1.0),
                    fontsize=10,
                    bbox=dict(facecolor='g', alpha=0.5, pad=0, edgecolor='none'),
                    color='white')

    _draw_box_set(ax, all_boxes, img_id, imdb.classes, threshold, show_class, 'y', 0.4, 0.5)
    if all_boxes_detection is not None:
        _draw_box_set(ax, all_boxes_detection, img_id, imdb.classes, threshold, show_class,
                      'm', 0.2, 0.2)
    if all_boxes_tracking is not None:
        _draw_box_set(ax, all_boxes_tracking, img_id, imdb.classes, threshold, show_class,
                      'w', 0.4, 0.5)
    return fig

In [None]:
def makeVid(imdb, all_boxes, img_ids, vid_id, dname_prefix='output/tmp_vid', dir_suffix='', ext='png',
            all_boxes_tracking=None, all_boxes_detection=None):
    """Save one annotated figure per frame of a video and return the written file paths.

    Files go to ``<dname_prefix>_<vid_id:03d><dir_suffix>/<frame:08d>.<ext>``. Missing
    directories (including parents such as ``output/``) are created.

    Args:
        imdb, all_boxes, all_boxes_tracking, all_boxes_detection: forwarded to showBoxResult.
        img_ids: global frame indices to render, in order.
        vid_id: video index used in the directory name.
        dname_prefix, dir_suffix: output directory name parts.
        ext: image format / extension passed to ``savefig`` (e.g. 'png', 'pdf').

    Returns:
        List of saved file paths, in the order of ``img_ids``.
    """
    dname = dname_prefix + '_%03d' % (vid_id) + dir_suffix
    # makedirs creates missing parents, which os.mkdir did not.
    if not os.path.isdir(dname):
        os.makedirs(dname)

    name_ls = []
    for i in img_ids:
        fname = os.path.join(dname, '%08d' % (i) + '.' + ext)
        fig = showBoxResult(i, all_boxes, imdb,
                            all_boxes_tracking=all_boxes_tracking,
                            all_boxes_detection=all_boxes_detection)
        try:
            fig.savefig(fname)
        finally:
            # Always release the figure; many frames would otherwise exhaust memory.
            plt.close(fig)
        name_ls.append(fname)
    return name_ls

In [None]:
# Per-[class][image] result tables, filled frame by frame by SEE_VIDS.
all_boxes = [[[] for _ in xrange(num_images)] for _ in xrange(imdb.num_classes)]
all_boxes_scores = [[[] for _ in xrange(num_images)] for _ in xrange(imdb.num_classes)]
# Debug overlays: tracking-only and detection-only boxes.
SEE_TRACKING = True
SEE_DETECTION = True
# Bind a disabled debug table to None, so later references to it don't raise NameError.
all_boxes_tracking = ([[[] for _ in xrange(num_images)] for _ in xrange(imdb.num_classes)]
                      if SEE_TRACKING else None)
all_boxes_detection = ([[[] for _ in xrange(num_images)] for _ in xrange(imdb.num_classes)]
                       if SEE_DETECTION else None)

In [None]:
from PyPDF2 import PdfFileMerger
def _nms_single_class(scores, pred_boxes, j, score_thresh):
    """Per-class NMS for class ``j`` of one frame.

    Args:
        scores: 2-D score tensor indexed ``[box, class]``.
        pred_boxes: 2-D box tensor; columns ``4*j:4*(j+1)`` hold class-``j`` boxes
            (the whole row when ``args.class_agnostic``).
        j: class index.
        score_thresh: boxes must score strictly above this to be considered.

    Returns:
        ``(kept_rows, cls_dets)``: the rows of ``scores`` that survived NMS, sorted by
        descending score, and the matching ``[x1, y1, x2, y2, score]`` tensor. Returns
        ``None`` when no box passes ``score_thresh``.
    """
    inds = torch.nonzero(scores[:, j] > score_thresh).view(-1)
    if inds.numel() == 0:
        return None
    cls_scores = scores[:, j][inds]
    _, order = torch.sort(cls_scores, 0, True)
    if args.class_agnostic:
        cls_boxes = pred_boxes[inds, :]
    else:
        cls_boxes = pred_boxes[inds][:, j * 4:(j + 1) * 4]
    cls_dets = torch.cat((cls_boxes, cls_scores.unsqueeze(1)), 1)[order]
    keep = nms(cls_dets, cfg.TEST.NMS).view(-1).long()
    return inds[order][keep], cls_dets[keep]


def _limit_detections_per_image(i):
    """Keep at most ``max_per_image`` detections (over all classes) for frame ``i``."""
    if max_per_image <= 0:
        return
    image_scores = np.hstack([all_boxes[j][i][:, -1] for j in xrange(1, imdb.num_classes)])
    if len(image_scores) <= max_per_image:
        return
    image_thresh = np.sort(image_scores)[-max_per_image]
    for j in xrange(1, imdb.num_classes):
        keep = np.where(all_boxes[j][i][:, -1] >= image_thresh)[0]
        all_boxes[j][i] = all_boxes[j][i][keep, :]
        all_boxes_scores[j][i] = all_boxes_scores[j][i][keep, :]


def _cross_class_nms(i):
    """Suppress overlapping boxes of different classes in frame ``i`` (cfg.TEST.NMS_CROSS_CLASS)."""
    present = [j for j in range(1, imdb.num_classes) if len(all_boxes[j][i]) > 0]
    if not present:
        return
    all_dets = np.vstack([all_boxes[j][i] for j in present])
    all_dets_scores = np.vstack([all_boxes_scores[j][i] for j in present])
    class_ids = np.vstack([np.ones(shape=(len(all_boxes[j][i]), 1)) * j for j in present])
    order = np.argsort(-all_dets[:, -1])
    all_dets, all_dets_scores, class_ids = all_dets[order], all_dets_scores[order], class_ids[order]
    keep = nms(torch.tensor(all_dets.astype(np.float32)).cuda(),
               cfg.TEST.NMS_CROSS_CLASS).view(-1).long().cpu().numpy()
    all_dets, all_dets_scores, class_ids = all_dets[keep], all_dets_scores[keep], class_ids[keep]
    for j in range(1, imdb.num_classes):
        idx_j = np.where(class_ids == j)[0]
        all_boxes[j][i] = all_dets[idx_j, :]
        all_boxes_scores[j][i] = all_dets_scores[idx_j, :]


def _store_frame_detections(i, scores, pred_boxes, empty_array):
    """Per-class NMS, per-image cap and optional cross-class NMS; fills all_boxes[*][i]."""
    for j in xrange(1, imdb.num_classes):
        result = None if scores is None else _nms_single_class(scores, pred_boxes, j, thresh)
        if result is None:
            all_boxes[j][i] = empty_array
            all_boxes_scores[j][i] = empty_array
        else:
            kept_rows, cls_dets = result
            all_boxes[j][i] = cls_dets.cpu().numpy()
            all_boxes_scores[j][i] = scores[kept_rows].cpu().numpy()
    _limit_detections_per_image(i)
    if cfg.TEST.NMS_CROSS_CLASS > 0.:
        _cross_class_nms(i)


def _store_tracking_debug(i, siam_scores, pred_siam_bbox, empty_array):
    """Record every positive-score siamese box of frame ``i`` (no NMS) in all_boxes_tracking."""
    for j in xrange(1, imdb.num_classes):
        inds = torch.nonzero(siam_scores[:, j] > 0).view(-1)
        if inds.numel() == 0:
            all_boxes_tracking[j][i] = empty_array
            continue
        cls_scores = siam_scores[:, j][inds]
        if args.class_agnostic:
            cls_boxes = pred_siam_bbox[inds, :]
        else:
            cls_boxes = pred_siam_bbox[inds][:, j * 4:(j + 1) * 4]
        all_boxes_tracking[j][i] = torch.cat((cls_boxes, cls_scores.unsqueeze(1)), 1).cpu().numpy()


def _process_frame(i, template_weights, rois_tracking, MODE, empty_array, track_thresh):
    """Run detector + siamese tracker on global frame ``i`` and store its results.

    Returns:
        ``(template_weights, rois_tracking)`` to feed into the next frame.
    """
    if not cfg.TEST.BBOX_REG:
        raise ValueError('Error. Should set cfg.TEST.BBOX_REG to True.')

    data = dataset.__getitem__(i)
    frame_im = data[0].unsqueeze(0)
    frame_info = data[1].unsqueeze(0)
    im_data.data.resize_(frame_im.size()).copy_(frame_im)
    im_info.data.resize_(frame_info.size()).copy_(frame_info)
    # Dividing by this maps boxes from the network input back to the original image.
    scale = frame_info[0][2].item()

    siam_rois, siam_bbox_pred, siam_cls_prob, rois, _, cls_prob, bbox_pred = RCNN(
        (im_data, im_info, template_weights, rois_tracking))

    scores = None
    pred_boxes = None
    # ---- detection boxes ----
    if rois is not None:
        pred_boxes = bbox_delta_to_pred_boxes(im_info, rois.data[:, :, 1:5], bbox_pred)
        # Second pass: re-run the RoI head on the regressed boxes.
        # NOTE(review): pred_boxes is presumably (1, R, 4*C); `[:, 4:]` slices the RoI axis
        # (dropping the first 4 RoIs), not the background-class columns -- confirm intent.
        pred_rois = pred_boxes[:, 4:].view(-1, 4)
        pred_rois = torch.cat((pred_boxes.new_zeros(pred_rois.size(0), 1), pred_rois), dim=1).unsqueeze(0)
        bbox_pred, cls_prob, _ = RCNN.RCNN.base_feat_to_roi_pred(RCNN.RCNN.base_feat_for_roi, pred_rois, None)
        pred_boxes = bbox_delta_to_pred_boxes(im_info, pred_rois.data[:, :, 1:5], bbox_pred)
        pred_boxes /= scale
        scores = cls_prob.data.squeeze()
        pred_boxes = pred_boxes.squeeze()
        if SEE_DETECTION:
            # Debug table: detection-only boxes after per-class NMS.
            for j in xrange(1, imdb.num_classes):
                result = _nms_single_class(scores, pred_boxes, j, 0)
                all_boxes_detection[j][i] = empty_array if result is None else result[1].cpu().numpy()

    # ---- siamese tracking boxes ----
    if siam_bbox_pred is not None:
        siam_scores = siam_cls_prob.data
        siam_boxes = siam_rois.data[:, 1:5]
        pred_siam_bbox = bbox_delta_to_pred_boxes(im_info, siam_boxes.unsqueeze(0), siam_bbox_pred.unsqueeze(0))
        pred_siam_bbox /= scale
        pred_siam_bbox = pred_siam_bbox.squeeze(0)
        if MODE == 0:
            # det+tra: score detection and regressed tracking boxes together.
            if rois is not None:
                pred_boxes = torch.cat((pred_boxes, pred_siam_bbox), 0)
                scores = torch.cat((scores, siam_scores), 0)
            else:
                pred_boxes = pred_siam_bbox
                scores = siam_scores
        elif MODE == 2:
            # tra only: raw (un-regressed) siamese boxes, one copy per class column.
            if args.class_agnostic:
                pred_boxes = siam_boxes / scale
            else:
                pred_boxes = siam_boxes.repeat(1, imdb.num_classes) / scale
            scores = siam_scores
        if SEE_TRACKING:
            _store_tracking_debug(i, siam_scores, pred_siam_bbox, empty_array)

    _store_frame_detections(i, scores, pred_boxes, empty_array)

    # Seed the tracker for the next frame with this frame's confident detections.
    next_rois = prepare_rois_tracking(im_info[0], all_boxes, all_boxes_scores, frame_id=i,
                                      class_num=imdb.num_classes, thresh=track_thresh)
    base_feat = RCNN.track_feat_trans_1.cuda()(RCNN.RCNN.Conv_feat_track)
    return siam_weights_preparation(next_rois, base_feat)


def _merge_pdfs(pdf_paths, out_path):
    """Concatenate the PDFs in ``pdf_paths`` (in order) into ``out_path``."""
    merger = PdfFileMerger()
    try:
        for pdf_path in pdf_paths:
            # Passing a path lets the merger own the handle and close it in close();
            # handles opened by the caller were never closed before.
            merger.append(pdf_path)
        with open(out_path, 'wb') as fout:
            merger.write(fout)
    finally:
        merger.close()


def SEE_VIDS(vids_id_list, MODE=0, track_thresh=0.9):
    """Run detection + tracking over whole videos and write one merged PDF per video.

    Args:
        vids_id_list: video indices into ``imdb._structured_indexes``.
        MODE: 0 = det+tra, 1 = det only, 2 = tra only (which boxes enter the final NMS).
        track_thresh: score a detection needs to be tracked into the next frame.
    """
    RCNN.eval()
    # Shared, read-only (0, 5) placeholder meaning "no boxes".
    empty_array = np.zeros((0, 5))
    dir_suffix = {1: '_det', 2: '_tra'}.get(MODE, '')

    for SEE_VID_ID in vids_id_list:
        vid_start = time.time()
        img_ids = imdb._structured_indexes[SEE_VID_ID]
        first, last = img_ids[0], img_ids[-1]
        template_weights, rois_tracking = None, None  # no template on the first frame
        # Inference only: without no_grad, template_weights (fed into the next frame)
        # can keep every previous frame's autograd graph alive.
        with torch.no_grad():
            for i in range(first, last + 1):
                sys.stdout.write('Processing vid %d: %d/%d. \r' % (SEE_VID_ID, i - first + 1, len(img_ids)))
                sys.stdout.flush()
                template_weights, rois_tracking = _process_frame(
                    i, template_weights, rois_tracking, MODE, empty_array, track_thresh)

        print('')
        # Per-video time (previously measured from the notebook-wide `start`).
        print("test time: %0.4fs" % (time.time() - vid_start))

        name_ls = makeVid(imdb, all_boxes, img_ids, SEE_VID_ID,
                          dname_prefix='output/tmp_vid',
                          dir_suffix=dir_suffix, ext='pdf',
                          all_boxes_tracking=all_boxes_tracking if SEE_TRACKING else None,
                          all_boxes_detection=all_boxes_detection if SEE_DETECTION else None)
        if not name_ls:
            continue
        result_pdf = os.path.join(os.path.dirname(name_ls[0]), 'result.pdf')
        _merge_pdfs(name_ls, result_pdf)
        print(result_pdf)

In [None]:
# Render videos 101..299 (indices into imdb._structured_indexes) with the default MODE=0 (det+tra).
VIDS_IDS = list(range(101,300))
SEE_VIDS(VIDS_IDS)

#### The cells below give finer-grained control (one video, explicit MODE and debug flags).

In [None]:
# Visualisation switches for the cells below.
# MODE: 0 = det+tra, 1 = det only, 2 = tra only.
MODE, SEE_TRACKING, SEE_DETECTION = 0, True, True

In [None]:
# Fresh per-[class][image] result tables for the single-video run below.
all_boxes = [[[] for _ in xrange(num_images)] for _ in xrange(imdb.num_classes)]
all_boxes_scores = [[[] for _ in xrange(num_images)] for _ in xrange(imdb.num_classes)]
# Bind a disabled debug table to None, so later references to it don't raise NameError.
all_boxes_tracking = ([[[] for _ in xrange(num_images)] for _ in xrange(imdb.num_classes)]
                      if SEE_TRACKING else None)
all_boxes_detection = ([[[] for _ in xrange(num_images)] for _ in xrange(imdb.num_classes)]
                       if SEE_DETECTION else None)

In [None]:
# Index into imdb._structured_indexes of the single video processed by the next cell.
SEE_VID_ID = 23

In [None]:
#data_iter = iter(dataloader)

_t = {'im_detect': time.time(), 'misc': time.time()}
det_file = os.path.join(output_dir, 'detections.pkl')

# Switch the model to evaluation mode before running the frames.
RCNN.eval()
# (0, 5) float array (transpose of five empty rows) used as the "no boxes" placeholder.
empty_array = np.transpose(np.array([[],[],[],[],[]]), (1,0))

# Tracker state: no template exists yet for the first frame of the video.
template_weights = None
rois_tracking = None


#for i in range(num_images):
for i in range(imdb._structured_indexes[SEE_VID_ID][0],imdb._structured_indexes[SEE_VID_ID][-1]+1):
    data = dataset.__getitem__(i)
    print('Processing vid %d: %d/%d.' % (SEE_VID_ID, i-imdb._structured_indexes[SEE_VID_ID][0]+1,
                                         len(imdb._structured_indexes[SEE_VID_ID])))
    data = list(data)
    data[0] = data[0].unsqueeze(0)
    data[1] = data[1].unsqueeze(0)
    data = tuple(data)
    #data[2] = data[2].unsqueeze(0)
    im_data.data.resize_(data[0].size()).copy_(data[0])
    im_info.data.resize_(data[1].size()).copy_(data[1])
    #gt_boxes.data.resize_(data[2].size()).copy_(data[2])
    #num_boxes.data.resize_(data[3].size()).copy_(data[3])

    input = im_data, im_info, template_weights, rois_tracking

    det_tic = time.time()
    siam_rois, siam_bbox_pred, siam_cls_prob, rois, rois_label, cls_prob, bbox_pred = RCNN(input)

    scores = None
    pred_boxes = None
    
    ###########################################
    # Get detection boxes.
    ###########################################
    if cfg.TEST.BBOX_REG:
        if rois is not None:
            boxes = rois.data[:, :, 1:5]
            pred_boxes = bbox_delta_to_pred_boxes(im_info, boxes, bbox_pred)
            
            pred_rois = pred_boxes[:,4:].view(-1,4)
            pred_rois = torch.cat((pred_boxes.new_zeros(pred_rois.size(0),1),pred_rois),dim=1).unsqueeze(0)
            bbox_pred, cls_prob, cls_score = RCNN.RCNN.base_feat_to_roi_pred(RCNN.RCNN.base_feat_for_roi, pred_rois, None)
            boxes = pred_rois.data[:, :, 1:5]
            pred_boxes = bbox_delta_to_pred_boxes(im_info, boxes, bbox_pred)
            scores = cls_prob.data
            
            pred_boxes /= data[1][0][2].item()
            scores = scores.squeeze()
            pred_boxes = pred_boxes.squeeze()
            ###TODO###
            # For debug only.
            if SEE_DETECTION:
                for j in xrange(1, imdb.num_classes):
                    inds = torch.nonzero(scores[:,j]>0).view(-1)
                    # if there is det
                    if inds.numel() > 0:
                        cls_scores = scores[:,j][inds]
                        _, order = torch.sort(cls_scores, 0, True)
                        if args.class_agnostic:
                            cls_boxes = pred_boxes[inds, :]
                        else:
                            cls_boxes = pred_boxes[inds][:, j * 4:(j + 1) * 4]
                        cls_dets = torch.cat((cls_boxes, cls_scores.unsqueeze(1)), 1)
                        #cls_dets = torch.cat((cls_boxes, cls_scores), 1)
                        cls_dets = cls_dets[order]
                        ######### nms for each cls here ########
                        keep = nms(cls_dets, cfg.TEST.NMS)
                        cls_dets = cls_dets[keep.view(-1).long()]
                        all_boxes_detection[j][i] = cls_dets.cpu().numpy()
                    else:
                        all_boxes_detection[j][i] = empty_array
            
        if siam_bbox_pred is not None:
            siam_scores = siam_cls_prob.data
            siam_boxes = siam_rois.data[:, 1:5]
            pred_siam_bbox = bbox_delta_to_pred_boxes(im_info, siam_boxes.unsqueeze(0), siam_bbox_pred.unsqueeze(0))                                                            
            pred_siam_bbox /= data[1][0][2].item()
            pred_siam_bbox = pred_siam_bbox.squeeze(0)
            # concatenate siambox and detbox.
            if MODE==0:
                if rois is not None:
                    pred_boxes = torch.cat((pred_boxes, pred_siam_bbox), 0)
                    scores = torch.cat((scores, siam_scores), 0)
                else:
                    pred_boxes = pred_siam_bbox
                    scores = siam_scores
            #####################
            if MODE==2:
                #####################
                #pred_boxes = pred_siam_bbox
                #####################
                pred_boxes = siam_boxes.repeat(1,siam_boxes.size(1)*31)
                pred_boxes = pred_boxes/im_info[0][-1]
                scores = siam_scores
            ###TODO###
            # For debug only.
            if SEE_TRACKING:
                for j in xrange(1, imdb.num_classes):
                    inds = torch.nonzero(siam_scores[:,j]>0).view(-1)
                    # if there is det
                    if inds.numel() > 0:
                        cls_scores = siam_scores[:,j][inds]
                        #_, order = torch.sort(cls_scores, 0, True)
                        if args.class_agnostic:
                            cls_boxes = pred_siam_bbox[inds, :]
                        else:
                            cls_boxes = pred_siam_bbox[inds][:, j * 4:(j + 1) * 4]
                        cls_dets = torch.cat((cls_boxes, cls_scores.unsqueeze(1)), 1)
                        # cls_dets = torch.cat((cls_boxes, cls_scores), 1)
                        #cls_dets = cls_dets[order]
                        ######### nms for each cls here ########
                        #keep = nms(cls_dets, cfg.TEST.NMS)
                        #cls_dets = cls_dets[keep.view(-1).long()]
                        all_boxes_tracking[j][i] = cls_dets.cpu().numpy()
                    else:
                        all_boxes_tracking[j][i] = empty_array
    else:
        raise ValueError('Error. Should set cfg.TEST.BBOX_REG to True.')

    det_toc = time.time()
    detect_time = det_toc - det_tic
    misc_tic = time.time()
    ###########################################
    # NMS for detection and save to all boxes.
    ###########################################
    if scores is not None:
        for j in xrange(1, imdb.num_classes):
            inds = torch.nonzero(scores[:,j]>thresh).view(-1)
            # if there is det
            if inds.numel() > 0:
                cls_scores = scores[:,j][inds]
                all_scores = scores[inds]
                _, order = torch.sort(cls_scores, 0, True)
                if args.class_agnostic:
                    cls_boxes = pred_boxes[inds, :]
                else:
                    cls_boxes = pred_boxes[inds][:, j * 4:(j + 1) * 4]
                cls_dets = torch.cat((cls_boxes, cls_scores.unsqueeze(1)), 1)
                # cls_dets = torch.cat((cls_boxes, cls_scores), 1)
                cls_dets = cls_dets[order]
                all_scores = all_scores[order]
                ######### nms for each cls here ########
                keep = nms(cls_dets, cfg.TEST.NMS)
                cls_dets = cls_dets[keep.view(-1).long()]
                all_cls_scores = all_scores[keep.view(-1).long()]
                all_boxes[j][i] = cls_dets.cpu().numpy()
                all_boxes_scores[j][i] = all_cls_scores.cpu().numpy()
            else:
                all_boxes[j][i] = empty_array
                all_boxes_scores[j][i] = empty_array
    else:
        for j in xrange(1, imdb.num_classes):
            all_boxes[j][i] = empty_array
            all_boxes_scores[j][i] = empty_array

    # Limit to max_per_image detections *over all classes*
    if max_per_image > 0:
        image_scores = np.hstack([all_boxes[j][i][:, -1]
                                for j in xrange(1, imdb.num_classes)])
        if len(image_scores) > max_per_image:
            image_thresh = np.sort(image_scores)[-max_per_image]
            for j in xrange(1, imdb.num_classes):
                keep = np.where(all_boxes[j][i][:, -1] >= image_thresh)[0]
                all_boxes[j][i] = all_boxes[j][i][keep, :]
                all_boxes_scores[j][i] = all_boxes_scores[j][i][keep, :]
                
    # nms between classes.
    if cfg.TEST.NMS_CROSS_CLASS > 0.:
        cls_boxes = [[]]+[all_boxes[j][i] for j in range(1, imdb.num_classes)]
        cls_boxes_scores = [[]]+[all_boxes_scores[j][i] for j in range(1, imdb.num_classes)]
        cls_box_list = [cls_boxes[j] for j in range(1, imdb.num_classes) if len(cls_boxes[j])>0]
        if len(cls_box_list)>0:
            all_dets = np.vstack(cls_box_list)
            all_dets_scores = np.vstack([cls_boxes_scores[j] for j in range(1, imdb.num_classes) if len(cls_boxes[j])>0])
            class_ids = np.vstack(
                [np.ones(shape=(len(cls_boxes[j]), 1))*j for j in range(1, imdb.num_classes) if len(cls_boxes[j])>0])
            _inds = np.argsort(-all_dets[:,-1])
            all_dets = all_dets[_inds,:]
            all_dets_scores = all_dets_scores[_inds,:]
            class_ids = class_ids[_inds,:]
            
            keep = nms(torch.tensor(all_dets.astype(np.float32)).cuda(), cfg.TEST.NMS_CROSS_CLASS).view(-1).long().cpu().numpy()
            all_dets = all_dets[keep, :]
            all_dets_scores = all_dets_scores[keep, :]
            class_ids = class_ids[keep, :]
            for j in range(1, imdb.num_classes):
                idx_j = np.where(class_ids==j)[0]
                all_boxes[j][i] = all_dets[idx_j, :]
                all_boxes_scores[j][i] = all_dets_scores[idx_j, :]
            
    ########
    # Get weights for the next iteration.
    ########
    # First, convert all_boxes to rois_tracking.#
    # TODO change threshold.
    rois_tracking = prepare_rois_tracking(im_info[0], all_boxes, all_boxes_scores, frame_id=i,
                                        class_num=imdb.num_classes, thresh=0.9)
    base_feat = RCNN.track_feat_trans_1.cuda()(RCNN.RCNN.Conv_feat_track)
    template_weights, rois_tracking = siam_weights_preparation(rois_tracking, base_feat)

    misc_toc = time.time()
    nms_time = misc_toc - misc_tic

    sys.stdout.write('im_detect: {:d}/{:d} {:.3f}s {:.3f}s   \r' \
      .format(i + 1, num_images, detect_time, nms_time))
    sys.stdout.flush()

#with open(det_file, 'wb') as f:
    #pickle.dump(all_boxes, f, pickle.HIGHEST_PROTOCOL)

#print('Evaluating detections')
#imdb.evaluate_detections(all_boxes, output_dir)

end = time.time()
print("test time: %0.4fs" % (end - start))

### Examine

In [None]:
# First and last global frame index of the inspected video.
_see_vid_frames = imdb._structured_indexes[SEE_VID_ID]
print(_see_vid_frames[0], _see_vid_frames[-1])

In [None]:
import scipy.io as sio
devkit_path = 'data/imagenet/ILSVRC/devkit'
#### ImageNet-VID class lookup tables ####
# Slot 0 is the background. For each of the 30 synsets in meta_vid.mat, entry
# [1][0] feeds _wnid and entry [2][0] feeds _classes, in file order.
synsets_video = sio.loadmat(os.path.join(devkit_path, 'data', 'meta_vid.mat'))
_vid_synsets = synsets_video['synsets'][0]
_wnid_list = [0]
_class_list = ['__background__']
for i in xrange(30):
    _class_list.append(_vid_synsets[i][2][0])
    _wnid_list.append(_vid_synsets[i][1][0])
_classes = tuple(_class_list)
_wnid = tuple(_wnid_list)

# Reverse lookups: wnid -> class index and class name -> class index.
_wnid_to_ind = dict(zip(_wnid, xrange(31)))
_class_to_ind = dict(zip(_classes, xrange(31)))
#######################
def read_annotation(dataPath):
    """Load ground-truth boxes and class ids for one ImageNet-VID frame.

    The annotation path is derived from the image path by replacing every
    'Data' with 'Annotations' and every 'JPEG' with 'xml'.

    Args:
        dataPath: path of the frame image.

    Returns:
        (boxes, gt_classes): an int32 array of shape (num_objs, 4) holding
        [xmin, ymin, xmax, ymax] (parsed as floats, stored as int32), and an
        int32 array of shape (num_objs,) with indices from _wnid_to_ind.

    Raises:
        AssertionError: if the annotation file does not exist.
        KeyError: if an object's <name> is not in _wnid_to_ind.
    """
    annotation_path = dataPath.replace('Data', 'Annotations').replace('JPEG', 'xml')
    assert os.path.exists(annotation_path), '%s' % (annotation_path)

    def _first_text(node, tag):
        # Text content of the first <tag> element below node.
        return node.getElementsByTagName(tag)[0].childNodes[0].data

    with open(annotation_path) as xml_file:
        document = minidom.parseString(xml_file.read())

    objects = document.getElementsByTagName('object')
    boxes = np.zeros((len(objects), 4), dtype=np.int32)
    gt_classes = np.zeros(len(objects), dtype=np.int32)
    for row, obj in enumerate(objects):
        coords = [float(_first_text(obj, tag)) for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
        wnid = str(_first_text(obj, "name")).lower().strip()
        boxes[row, :] = coords
        gt_classes[row] = _wnid_to_ind[wnid]
    return boxes, gt_classes

In [None]:
# Sentinel meaning "caller did not pass all_boxes_detection" (distinct from an
# explicit None, which disables that overlay).
_DETECTION_FROM_GLOBAL = object()


def _draw_labeled_box(ax, x1, y1, x2, y2, label, color, box_alpha, label_alpha, show_class):
    """Draw one unfilled rectangle on ax and, if show_class, a white label at its top-left."""
    ax.add_patch(
        plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                      fill=False, edgecolor=color,
                      linewidth=2.0, alpha=box_alpha))
    if show_class:
        ax.text(x1, y1 - 2, label,
                fontsize=10,
                bbox=dict(facecolor=color, alpha=label_alpha, pad=0, edgecolor='none'),
                color='white')


def _draw_class_boxes(ax, boxes_per_class, img_id, classes, threshold,
                      color, box_alpha, label_alpha, show_class):
    """Draw every box in boxes_per_class[j][img_id] whose last column (score) > threshold.

    Entries that are not numpy arrays (e.g. the background slot, which stays an
    empty list) are skipped.
    """
    for j in range(len(classes)):
        dets = boxes_per_class[j][img_id]
        if not isinstance(dets, np.ndarray):
            continue
        for bbox in dets:
            if bbox[-1] > threshold:
                _draw_labeled_box(ax, bbox[0], bbox[1], bbox[2], bbox[3],
                                  '%s %f' % (classes[j], bbox[-1]),
                                  color, box_alpha, label_alpha, show_class)


def showBoxResult(img_id, all_boxes, imdb, show_class=True, threshold=0.3, dpi = 200,all_boxes_tracking=None,
                  all_boxes_detection=_DETECTION_FROM_GLOBAL):
    """Draw ground truth and predicted boxes for one frame on a borderless figure.

    Layers, drawn in this order:
      * ground truth from the frame's XML annotation (green, always drawn,
        labelled with score 1.0);
      * all_boxes (yellow);
      * all_boxes_detection (magenta, faint), if available;
      * all_boxes_tracking (white), if given.
    Predicted layers are indexed as boxes[class_idx][img_id]; each entry is an
    array whose last column is the score. Only boxes with score > threshold are
    drawn, and non-array entries are skipped.

    Args:
        img_id: global image index into imdb.image_index and the box buffers.
        all_boxes: final per-class detections.
        imdb: dataset object providing image_index and classes.
        show_class: if True, annotate each box with "<class> <score>".
        threshold: minimum score for predicted boxes to be drawn.
        dpi: pixels per inch used to size the figure from the image shape.
        all_boxes_tracking: optional tracking-box buffer (None = not drawn).
        all_boxes_detection: optional raw-detection buffer. When omitted, the
            notebook-global ``all_boxes_detection`` is used if it exists (the
            historical behaviour); pass None to disable the overlay.

    Returns:
        The matplotlib Figure; the caller is responsible for closing it.

    Raises:
        AssertionError: if the image file does not exist.
    """
    if all_boxes_detection is _DETECTION_FROM_GLOBAL:
        # Backward compatibility with callers that relied on the notebook
        # global; .get avoids a NameError when it was never defined.
        all_boxes_detection = globals().get('all_boxes_detection')

    imIdx = imdb.image_index[img_id] + '.JPEG'
    assert os.path.exists(imIdx), imIdx + ' does not exist.'
    gt_boxes, gt_classes = read_annotation(imIdx)
    # Close the image file deterministically once the pixels are copied out.
    with PIL.Image.open(imIdx) as pil_img:
        im2show = np.array(pil_img)

    # Borderless figure sized to (width / dpi, height / dpi) inches with the
    # axes covering the whole canvas.
    fig = plt.figure(frameon=False)
    fig.set_size_inches(im2show.shape[1] / dpi, im2show.shape[0] / dpi)
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.axis('off')
    fig.add_axes(ax)
    ax.imshow(im2show)

    # Ground truth first, so predictions are drawn on top of it.
    for j in range(len(gt_classes)):
        _draw_labeled_box(ax, gt_boxes[j][0], gt_boxes[j][1], gt_boxes[j][2], gt_boxes[j][3],
                          '%s %f' % (imdb.classes[gt_classes[j]], 1.0),
                          'g', 0.8, 0.5, show_class)

    _draw_class_boxes(ax, all_boxes, img_id, imdb.classes, threshold,
                      'y', 0.4, 0.5, show_class)
    if all_boxes_detection is not None:
        _draw_class_boxes(ax, all_boxes_detection, img_id, imdb.classes, threshold,
                          'm', 0.2, 0.2, show_class)
    if all_boxes_tracking is not None:
        _draw_class_boxes(ax, all_boxes_tracking, img_id, imdb.classes, threshold,
                          'w', 0.4, 0.5, show_class)
    return fig

In [None]:
def makeVid(imdb,all_boxes,img_ids,vid_id,dname_prefix='output/tmp_vid',dir_suffix='',ext='png',all_boxes_tracking=None,
            all_boxes_detection=None):
    """Render one annotated figure per frame of a video and save each to disk.

    Files are written to '<dname_prefix>_<vid_id:03d><dir_suffix>/<frame:08d>.<ext>'.

    Args:
        imdb: dataset object forwarded to showBoxResult.
        all_boxes: final per-class detection buffer.
        img_ids: global frame indexes to render, in output order.
        vid_id: video id used in the output directory name.
        dname_prefix: output directory prefix; missing parent dirs are created.
        dir_suffix: appended to the directory name (e.g. '_det', '_tra').
        ext: image/document extension passed to savefig via the file name.
        all_boxes_tracking: optional tracking-box buffer to overlay.
        all_boxes_detection: optional raw-detection buffer to overlay; when
            None it is not forwarded, so showBoxResult's own default applies.

    Returns:
        List of written file paths, in the order of img_ids.
    """
    dname = dname_prefix + '_%03d' % (vid_id) + dir_suffix
    # makedirs (not mkdir) so a missing parent such as 'output/' is created too.
    if not os.path.isdir(dname):
        os.makedirs(dname)
    extra_kwargs = {}
    if all_boxes_detection is not None:
        extra_kwargs['all_boxes_detection'] = all_boxes_detection
    name_ls = []
    for i in img_ids:
        fname = os.path.join(dname, '%08d' % (i) + '.' + ext)
        fig = showBoxResult(i, all_boxes, imdb, all_boxes_tracking=all_boxes_tracking, **extra_kwargs)
        try:
            fig.savefig(fname)
        finally:
            # Always release the figure so long videos do not accumulate them.
            plt.close(fig)
        name_ls.append(fname)
    return name_ls

In [None]:
# Render every frame of video SEE_VID_ID as a PDF under output/tmp_vid_<id><suffix>/.
img_ids = imdb._structured_indexes[SEE_VID_ID]
# Output-directory suffix per MODE: 1 -> detection only, 2 -> tracking only, else none.
dir_suffix = {1: '_det', 2: '_tra'}.get(MODE, '')
# When tracking capture is off, all_boxes_tracking may be undefined (NameError)
# or stale from an earlier run, so pass None and skip that overlay.
name_ls = makeVid(imdb, all_boxes, img_ids, SEE_VID_ID,
                  dname_prefix='output/tmp_vid', dir_suffix=dir_suffix, ext='pdf',
                  all_boxes_tracking=all_boxes_tracking if SEE_TRACKING else None)

### Merge all per-frame PDFs into a single result.pdf.

In [None]:
from PyPDF2 import PdfFileMerger

# Concatenate the per-frame PDFs (in frame order) into result.pdf in the same directory.
pdfs = name_ls
if not pdfs:
    raise ValueError('No per-frame PDFs to merge; run the makeVid cell first.')
merger = PdfFileMerger()
try:
    for pdf in pdfs:
        # Pass the path so the merger opens the file itself and releases it on
        # close(); the previous open(pdf, 'rb') handles were never closed.
        merger.append(pdf)
    with open(os.path.join(os.path.dirname(pdfs[0]), 'result.pdf'), 'wb') as fout:
        merger.write(fout)
finally:
    merger.close()

In [None]:
# Location of the merged PDF written by the previous cell.
result_pdf = os.path.join(os.path.dirname(pdfs[0]), 'result.pdf')
print(result_pdf)