In [2]:
import numpy as np
import argparse
import os
import pickle
import cv2
from train_utils import *
from models import *
import time

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from models import *
from dataset import WiderDataset
from train_utils import *

from detect_utils import detect_faces

from config import args

In [2]:
def test_net(test_mode="rnet", thresh=[0.6, 0.7, 0.7], min_face_size=12):

    cuda = True

    detectors = [None, None, None]

    # load pnet model
    pnet = PNet()
    if cuda:
        pnet = pnet.cuda()
    checkpoint = torch.load("pnet.pth.tar")
    pnet.load_state_dict(checkpoint['state_dict'])
    detectors[0] = pnet
    
    # load rnet model
    if test_mode in ["rnet", "onet"]:
        rnet = PNet()
        if cuda:
            rnet = rnet.cuda()
        checkpoint = torch.load("model_best_rnet.pth.tar")
        rnet.load_state_dict(checkpoint['state_dict'])
        detectors[1] = rnet

    # load onet model
    if test_mode == "onet":
        onet = ONet()
        if cuda:
            onet = onet.cuda()
        checkpoint = torch.load("model_best_onet.pth.tar")
        onet.load_state_dict(checkpoint['state_dict'])
        detectors[2] = onet
        
        
    dataset = WiderDataset(
        os.path.join(args.data, "anno.txt"),
        os.path.join(args.data, "WIDER_train/images"))
    
    
    detections = []
    for i, (input, _) in enumerate(dataset):
        boxes = detect_faces(detectors, input)
        detections.append(boxes)
        #break
        
        if i%100 == 0:
            print("{} images done".format(i))
            
    print("saving detections")
    
    # dummy detection for debug:
    #detections = [np.array([[10, 30, 20, 50, 1.0]]) for _ in range(len(dataset))]
        
    
    if test_mode == "pnet":
        net = "rnet"
    elif test_mode == "rnet":
        net = "onet"

    save_path = os.path.join(args.data, net)
    if not os.path.exists(save_path):
        os.mkdir(save_path)
    save_file = os.path.join(save_path, "detections.pkl")
    with open(save_file, 'wb') as f:
        pickle.dump(detections, f, pickle.HIGHEST_PROTOCOL)

In [72]:
# Run the P-Net stage over the WIDER training images and pickle its detections
# under <args.data>/rnet/, where the hard-example step below picks them up.
test_net(test_mode="pnet")

saving detections


In [5]:
# Load the pickled detections and turn them into training crops + label files
# for the next stage (here R-Net).
# NOTE: save_hard_example is defined in a later cell; run that cell first (or
# move it above this one) so this cell works on a fresh kernel.
net_mode = "rnet"
save_hard_example(net=net_mode, basepath=args.data)

processing 12880 images in total
0 images done
100 images done
200 images done


KeyboardInterrupt: 

In [4]:
from box_utils import *

def _load_gt_boxes(anno_file):
    """Parse ``anno_file`` into parallel lists of image paths and ground-truth boxes.

    Each line is ``image/path x1 y1 x2 y2 [x1 y1 x2 y2 ...]``; the coordinates of
    one line become a float32 array reshaped to ``(n_boxes, 4)``.
    """
    image_paths = []
    gt_boxes = []
    with open(anno_file, 'r') as f:
        for line in f:
            fields = line.strip().split(' ')
            image_paths.append(fields[0])
            coords = list(map(float, fields[1:]))
            gt_boxes.append(np.array(coords, dtype=np.float32).reshape(-1, 4))
    return image_paths, gt_boxes


def save_hard_example(net, basepath):
    """Turn pickled detections into labelled training crops for ``net``.

    Reads ``<basepath>/anno.txt`` (ground truth) and ``<basepath>/<net>/detections.pkl``
    (one detection array per annotation line, as written by ``test_net``). Each detection
    is squared, cropped from the image, resized to the stage's input size and classified
    by its best IoU with the ground truth:

    * IoU < 0.3         -> negative (label 0)
    * IoU >= 0.65       -> positive (label 1 + bbox regression offsets)
    * 0.4 <= IoU < 0.65 -> part face (label -1 + bbox regression offsets)
    * 0.3 <= IoU < 0.4  -> discarded

    Crops are saved to ``<basepath>/<size>/{negative,positive,part}/<idx>.jpg`` and labels to
    ``<basepath>/<net>/{neg,pos,part}_<size>.txt``, where size is 24 (rnet) or 48 (onet).

    Raises:
        ValueError: if ``net`` is not "rnet" or "onet".
        AssertionError: if the number of detections differs from the number of annotations.
        FileNotFoundError: if an image that has detections cannot be read.
    """
    sizes = {"rnet": 24, "onet": 48}
    if net not in sizes:
        raise ValueError("net must be 'rnet' or 'onet', got %r" % (net,))
    image_size = sizes[net]

    image_dir = os.path.join(basepath, "WIDER_train/images")
    # Crops live under <basepath>/<image_size>/..., matching the "<image_size>/<category>/<idx>"
    # prefix written to the label files. (No leading "/" here: os.path.join would discard basepath.)
    crop_root = os.path.join(basepath, str(image_size))
    neg_save_dir = os.path.join(crop_root, "negative")
    pos_save_dir = os.path.join(crop_root, "positive")
    part_save_dir = os.path.join(crop_root, "part")
    for crop_dir in (neg_save_dir, pos_save_dir, part_save_dir):
        # cv2.imwrite does not create directories; it silently returns False instead.
        os.makedirs(crop_dir, exist_ok=True)
    anno_file = os.path.join(basepath, 'anno.txt')
    save_path = os.path.join(basepath, net)

    im_idx_list, gt_boxes_list = _load_gt_boxes(anno_file)
    num_of_images = len(im_idx_list)
    print("processing %d images in total"%num_of_images)

    # NOTE(security): pickle.load can execute arbitrary code; only load detections produced locally.
    with open(os.path.join(save_path, 'detections.pkl'), 'rb') as f:
        det_boxes = pickle.load(f)
    assert len(det_boxes) == num_of_images, "incorrect amount of detections for ground truths"

    # Running indices of neg, pos and part crops, used as their file names.
    n_idx = 0
    p_idx = 0
    d_idx = 0
    with open(os.path.join(save_path, 'pos_%d.txt'%image_size), 'w') as f1, \
            open(os.path.join(save_path, 'neg_%d.txt'%image_size), 'w') as f2, \
            open(os.path.join(save_path, 'part_%d.txt'%image_size), 'w') as f3:
        for image_done, (im_idx, dets, gts) in enumerate(zip(im_idx_list, det_boxes, gt_boxes_list)):
            if image_done % 100 == 0:
                print("%d images done"%image_done)

            if dets is None or len(dets) == 0:
                continue
            image_path = os.path.join(image_dir, im_idx)
            img = cv2.imread(image_path)
            if img is None:
                raise FileNotFoundError("could not read image %s" % image_path)
            img_h, img_w = img.shape[:2]

            dets = convert_to_square(dets)
            dets[:, 0:4] = np.round(dets[:, 0:4])

            for box in dets:
                x_left, y_top, x_right, y_bottom = box[:4].astype(int)
                width = x_right - x_left + 1
                height = y_bottom - y_top + 1

                # ignore box that is too small or beyond image border
                if width < 20 or x_left < 0 or y_top < 0 or x_right > img_w - 1 or y_bottom > img_h - 1:
                    continue

                # IoU between the current box and all gt boxes; an image without any
                # gt box makes every detection a negative (np.max of [] would raise).
                if len(gts) == 0:
                    iou = None
                    max_iou = 0.0
                else:
                    iou = IoU(box, gts)
                    max_iou = float(np.max(iou))

                if 0.3 <= max_iou < 0.4:
                    # Ambiguous overlap: neither a clean negative nor a part face.
                    continue

                cropped_im = img[y_top:y_bottom + 1, x_left:x_right + 1, :]
                resized_im = cv2.resize(cropped_im, (image_size, image_size),
                                        interpolation=cv2.INTER_LINEAR)

                if max_iou < 0.3:
                    # IoU with all gts is below 0.3
                    f2.write("%s/negative/%s"%(image_size, n_idx) + ' 0\n')
                    cv2.imwrite(os.path.join(neg_save_dir, "%s.jpg"%n_idx), resized_im)
                    n_idx += 1
                    continue

                # bbox regression target w.r.t. the gt box with the highest IoU
                x1, y1, x2, y2 = gts[np.argmax(iou)]
                offsets = ((x1 - x_left) / float(width),
                           (y1 - y_top) / float(height),
                           (x2 - x_right) / float(width),
                           (y2 - y_bottom) / float(height))

                if max_iou >= 0.65:
                    f1.write("%s/positive/%s"%(image_size, p_idx) + ' 1 %.2f %.2f %.2f %.2f\n'%offsets)
                    cv2.imwrite(os.path.join(pos_save_dir, "%s.jpg"%p_idx), resized_im)
                    p_idx += 1
                else:  # 0.4 <= max_iou < 0.65
                    f3.write("%s/part/%s"%(image_size, d_idx) + ' -1 %.2f %.2f %.2f %.2f\n'%offsets)
                    cv2.imwrite(os.path.join(part_save_dir, "%s.jpg"%d_idx), resized_im)
                    d_idx += 1