forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
video_input_op.cc
93 lines (84 loc) · 3.41 KB
/
video_input_op.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#include <caffe2/video/video_input_op.h>
namespace caffe2 {

// Registers VideoInputOp<CPUContext> as the CPU implementation of the
// "VideoInput" operator.
REGISTER_CPU_OPERATOR(VideoInput, VideoInputOp<CPUContext>);

// Schema for VideoInput: 0-1 inputs, 2-5 outputs.
//
// The tensor inference function derives every output shape purely from the
// operator arguments; the input shapes are ignored. Outputs are emitted in
// this order, each optional one only when its flag is set:
//   1. (get_rgb)          FLOAT {N, channels_rgb, length_rgb, crop, crop}
//   2. (get_optical_flow) FLOAT {N, channels_of,  length_of,  crop, crop}
//   3. (always)           INT32 {1, N}, or {N, num_of_class} when
//                         do_multi_label is set
//   4. (get_video_id)     INT64 {1, N}
//   5. (get_start_frame)  INT32 {1, N}
// where N = batch_size * clip_per_video and crop = crop_size.
//
// NOTE(review): with get_rgb=false and every other flag at its default only
// one shape is produced, which is below the NumOutputs(2, 5) lower bound --
// confirm whether that configuration is expected to be rejected elsewhere.
OPERATOR_SCHEMA(VideoInput)
    .NumInputs(0, 1)
    .NumOutputs(2, 5)
    .TensorInferenceFunction(
        [](const OperatorDef& def,
           const vector<TensorShape>& /* in (unused) */) {
          ArgumentHelper args(def);

          // Scalar sizing arguments.
          const int batch_size = args.GetSingleArgument<int>("batch_size", 0);
          int clips_per_video =
              args.GetSingleArgument<int>("clip_per_video", 1);
          const int crop = args.GetSingleArgument<int>("crop_size", -1);
          const int rgb_length = args.GetSingleArgument<int>("length_rgb", 0);
          const int rgb_channels =
              args.GetSingleArgument<int>("channels_rgb", 3);
          const int flow_length = args.GetSingleArgument<int>("length_of", 0);
          const int flow_channels =
              args.GetSingleArgument<int>("channels_of", 2);

          // Flags selecting which outputs exist.
          const bool want_rgb = args.GetSingleArgument<bool>("get_rgb", true);
          const bool want_flow =
              args.GetSingleArgument<bool>("get_optical_flow", false);
          const bool multi_label =
              args.GetSingleArgument<bool>("do_multi_label", false);
          const bool want_video_id =
              args.GetSingleArgument<bool>("get_video_id", false);
          const bool want_start_frame =
              args.GetSingleArgument<bool>("get_start_frame", false);

          // Explicit start positions, when supplied, take precedence over
          // the clip_per_video argument: one clip per listed position.
          const vector<int> start_positions =
              args.GetRepeatedArgument<int>("clip_start_positions", {});
          if (!start_positions.empty()) {
            clips_per_video = static_cast<int>(start_positions.size());
          }

          // One unconditional output plus one per enabled flag.
          const int num_outputs = 1 + static_cast<int>(want_rgb) +
              static_cast<int>(want_flow) + static_cast<int>(want_video_id) +
              static_cast<int>(want_start_frame);
          vector<TensorShape> shapes;
          shapes.reserve(num_outputs);

          // crop_size defaults to -1, so it must be given explicitly.
          CHECK_GT(crop, 0);

          // Leading dimension shared by all outputs.
          const int total_clips = batch_size * clips_per_video;

          // Builds the 5-D FLOAT shape used by both video data outputs.
          const auto video_shape = [total_clips, crop](
                                       int channels, int length) {
            return CreateTensorShape(
                vector<int>{total_clips, channels, length, crop, crop},
                TensorProto::FLOAT);
          };

          if (want_rgb) {
            shapes.push_back(video_shape(rgb_channels, rgb_length));
          }
          if (want_flow) {
            shapes.push_back(video_shape(flow_channels, flow_length));
          }

          // NOTE(review): the single-label shape is the 2-D {1, N}; verify
          // against the shape the operator actually produces at run time.
          if (multi_label) {
            // num_of_class is only consulted in multi-label mode.
            const int num_classes =
                args.GetSingleArgument<int>("num_of_class", 0);
            shapes.push_back(CreateTensorShape(
                vector<int>{total_clips, num_classes}, TensorProto::INT32));
          } else {
            shapes.push_back(CreateTensorShape(
                vector<int>{1, total_clips}, TensorProto::INT32));
          }

          if (want_video_id) {
            shapes.push_back(CreateTensorShape(
                vector<int64_t>{1, total_clips}, TensorProto::INT64));
          }
          if (want_start_frame) {
            shapes.push_back(CreateTensorShape(
                vector<int>{1, total_clips}, TensorProto::INT32));
          }

          return shapes;
        });

// VideoInput is declared to have no gradient.
NO_GRADIENT(VideoInput);

} // namespace caffe2