-
Notifications
You must be signed in to change notification settings - Fork 18.7k
/
classifier.py
94 lines (79 loc) · 3.24 KB
/
classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#!/usr/bin/env python
"""
Classifier is an image classifier specialization of Net.
"""
import numpy as np
import caffe
class Classifier(caffe.Net):
    """
    Classifier extends Net for image class prediction
    by scaling, center cropping, or oversampling.
    """
    def __init__(self, model_file, pretrained_file, image_dims=None,
                 gpu=False, mean=None, input_scale=None, raw_scale=None,
                 channel_swap=None):
        """
        Take
        model_file, pretrained_file: forwarded unchanged to
            caffe.Net.__init__.
        image_dims: dimensions to scale input for cropping/sampling.
            Default is to scale to net input size for whole-image crop.
        gpu: select GPU mode when true, CPU mode otherwise.
        mean, input_scale, raw_scale, channel_swap: params for
            preprocessing options; each is applied to the first net
            input only when it is not None.
        """
        caffe.Net.__init__(self, model_file, pretrained_file)
        self.set_phase_test()

        if gpu:
            self.set_mode_gpu()
        else:
            self.set_mode_cpu()

        # Preprocessing options are all registered on the first input blob.
        if mean is not None:
            self.set_mean(self.inputs[0], mean)
        if input_scale is not None:
            self.set_input_scale(self.inputs[0], input_scale)
        if raw_scale is not None:
            self.set_raw_scale(self.inputs[0], raw_scale)
        if channel_swap is not None:
            self.set_channel_swap(self.inputs[0], channel_swap)

        # Crop size is taken from the trailing (spatial) axes of the
        # first input blob.
        self.crop_dims = np.array(self.blobs[self.inputs[0]].data.shape[2:])
        # Explicit None/empty test: `not image_dims` raises ValueError
        # ("truth value of an array is ambiguous") for a multi-element
        # ndarray, while None and empty sequences must still fall back
        # to the crop size.
        if image_dims is None or np.size(image_dims) == 0:
            image_dims = self.crop_dims
        self.image_dims = image_dims

    def predict(self, inputs, oversample=True):
        """
        Predict classification probabilities of inputs.

        Take
        inputs: iterable of (H x W x K) input ndarrays. It must support
            len() and indexing, since the batch buffer is sized from
            len(inputs) and inputs[0].shape[2].
        oversample: average predictions across center, corners, and mirrors
            when True (default). Center-only prediction when False.

        Give
        predictions: (N x C) ndarray of class probabilities
            for N images and C classes.
        """
        # Scale to standardize input dimensions.
        input_ = np.zeros((len(inputs),
                           self.image_dims[0], self.image_dims[1],
                           inputs[0].shape[2]),
                          dtype=np.float32)
        for ix, in_ in enumerate(inputs):
            input_[ix] = caffe.io.resize_image(in_, self.image_dims)

        if oversample:
            # Generate center, corner, and mirrored crops.
            input_ = caffe.io.oversample(input_, self.crop_dims)
        else:
            # Take center crop: [top, left, bottom, right] around the
            # image center.
            center = np.array(self.image_dims) / 2.0
            crop = np.tile(center, (1, 2))[0] + np.concatenate([
                -self.crop_dims / 2.0,
                self.crop_dims / 2.0
            ])
            # The bounds are floats because of the divisions above; slice
            # indices must be integers, so truncate explicitly.
            crop = crop.astype(int)
            input_ = input_[:, crop[0]:crop[2], crop[1]:crop[3], :]

        # Classify
        # Reorder the batch shape from (N, H, W, K) to (N, K, H, W).
        caffe_in = np.zeros(np.array(input_.shape)[[0, 3, 1, 2]],
                            dtype=np.float32)
        for ix, in_ in enumerate(input_):
            caffe_in[ix] = self.preprocess(self.inputs[0], in_)
        out = self.forward_all(**{self.inputs[0]: caffe_in})
        # Drops axes 2 and 3 of the first output blob; this requires a
        # 4-D output whose last two axes have length 1.
        predictions = out[self.outputs[0]].squeeze(axis=(2, 3))

        # For oversampling, average predictions across crops.
        if oversample:
            # The reshape assumes 10 crops per image. Floor division keeps
            # the group count an integer on Python 3 as well as Python 2.
            predictions = predictions.reshape(
                (len(predictions) // 10, 10, -1))
            predictions = predictions.mean(1)

        return predictions