# FractionalMaxPool2d / FractionalMaxPool3d API Design Document

| API name | FractionalMaxPool2d / FractionalMaxPool3d |
| - | - |
| Author | megemini(柳顺) |
| Submitted | 2024-01-12 |
| Version | V2.1 |
| Required Paddle version | develop |
| File name | 20231009_api_design_for_fractional_max_pool.md |

#### Revision history

- v2.0: switched the implementation from Python to C++
- v2.1: revised the API signatures
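# I. Overview

## 1. Background

The paper [*Fractional Max-Pooling*](https://arxiv.org/abs/1412.6071) introduces a *fractional* pooling method. Traditional pooling such as `max-pooling` shrinks the input by an integer factor. In `Fractional Max-Pooling` the pooling factor can lie in `1 < alpha < 2`, so each pooling step can shrink the input by a factor such as `sqrt(2)` instead of exactly `2`. For example, an input of size `25` can be reduced to an output of size `18`, giving `alpha = 25/18 ≈ 1.39`.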
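The paper argues that traditional pooling shrinks the input size too quickly, which hurts performance, and that this pooling method avoids the problem. Because the network sees its input at more intermediate sizes, the overall recognition ability of the model can improve.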
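One way to see the "shrinks too quickly" point is to count how many pooling layers it takes to bring a side of length 25 down to 1. The sketch below is our own illustration, not from the paper. It assumes each layer simply maps `N_in` to `floor(N_in / alpha)`, and the helper name `shrink_chain` is hypothetical.

```python
import math


def shrink_chain(size: int, alpha: float) -> list[int]:
    """Repeatedly apply size -> floor(size / alpha) until the size reaches 1."""
    chain = [size]
    while size > 1:
        size = int(size / alpha)  # int() floors positive values
        chain.append(size)
    return chain


# Classic 2x2 max pooling (alpha = 2) versus fractional pooling with alpha = sqrt(2).
print(shrink_chain(25, 2.0))           # [25, 12, 6, 3, 1]
print(shrink_chain(25, math.sqrt(2)))  # [25, 17, 12, 8, 5, 3, 2, 1]
```

With `alpha = 2` the input is exhausted after four layers. With `alpha = sqrt(2)` it takes seven, so the network can stack more pooling layers and pass through more intermediate scales.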
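Paddle already implements methods such as `max-pooling` and `avg-pooling`, but not `fractional max pooling`. This proposal adds `fractional max pool2d / fractional max pool3d` to broaden Paddle's API coverage.

## 2. Goals

Apply fractional max pooling over an input signal composed of several channels. For details of fractional max pooling, see the paper [*Fractional Max-Pooling*](https://arxiv.org/abs/1412.6071).

Call forms: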
- `paddle.nn.FractionalMaxPool2d`
- `paddle.nn.FractionalMaxPool3d`
- `paddle.nn.functional.fractional_max_pool2d`
- `paddle.nn.functional.fractional_max_pool3d`
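## 3. Significance

Add the `Fractional Max-Pooling` operation to `Paddle`, enriching its set of pooling APIs.

# II. Current State of Paddle

Paddle already provides many pooling methods, such as `max_poolNd` and `avg_poolNd`. It does not yet provide `fractional_max_pool`, and there is no corresponding underlying operator.

Paddle places its pooling functions in `python/paddle/nn/functional/pooling.py`. The modules used to build networks live in `python/paddle/nn/layer/pooling.py`, and each of those layers is implemented by calling the corresponding `functional` function.

Accordingly, `paddle.nn.FractionalMaxPoolNd` can be implemented by calling `paddle.nn.functional.fractional_max_poolNd`.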
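# III. Survey of Existing Solutions

## Algorithm

Start from `2*2 max pooling` (2MP), whose sampling sequence is `22222...`. If some `1`s are mixed in, as in `1121122112...`, the result is pooling with `1 < alpha = N_in/N_out < 2`.

The key to the algorithm is therefore how to generate a sequence like `1121122112...` that matches `output_size` (or `input_size * output_ratio`).

Note: `1` and `2` can be read as the `kernel/pool size`, i.e. the size of each pooling step, or as the `increments` in the paper. They are `1` and `2` only because `1 < alpha < 2`, which places the operation between the `original size` and `2*2 max pooling`. If `alpha > 2`, the analogue is something like `3*3 max pooling`, and the sequence can contain larger positive integers. For simplicity, the rest of this discussion assumes `1 < alpha < 2`.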
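The paper describes two approaches: *random* (`random`) and *pseudo-random* (`pseudo random`).

- *Random* (`random`)

  Randomly generate a sequence of `1`s and `2`s such that:

  - the sequence length is `output_size`
  - the sum of the sequence is `input_size`

- *Pseudo-random* (`pseudo random`)

  The generated cumulative sequence must satisfy:

  `a = ceil(alpha(i+u)), 1 < alpha = N_in/N_out < 2, 0 < u < 1, i = 0,1,2...N_out`

  Its length is `output_size + 1`, and `u` is a random number that can be fixed with a random seed. The increments are then:

  `diff = a[i+1] - a[i]`

Once the sequence is generated, a `max` is taken over each pooling window to produce the final output.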
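The *random* variant and the final max step can be sketched in a few lines of pure Python, shown in 1D for readability. The helper names are our own. The "split into `k` and `k + 1`, then shuffle" strategy matches the TensorFlow code analyzed later, which is one way of satisfying the two constraints above and is not necessarily the paper's exact sampling procedure.

```python
import random


def random_increments(n_in: int, n_out: int, rng: random.Random) -> list[int]:
    """'Random' variant: a shuffled mix of k and k + 1 (1s and 2s when 1 < alpha < 2)."""
    if not 0 < n_out <= n_in:
        raise ValueError("need 0 < n_out <= n_in")
    k, extra = divmod(n_in, n_out)
    diff = [k + 1] * extra + [k] * (n_out - extra)  # n_out increments summing to n_in
    rng.shuffle(diff)
    return diff


def fractional_max_pool1d(x: list, diff: list[int]) -> list:
    """Max over each disjoint window [start, start + d), with start advancing by d."""
    out, start = [], 0
    for d in diff:
        out.append(max(x[start:start + d]))
        start += d
    return out


rng = random.Random(0)
x = list(range(25))                    # toy 1-D input of length 25
diff = random_increments(25, 18, rng)  # seven 2s and eleven 1s, in random order
print(fractional_max_pool1d(x, diff))  # 18 values
```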
## PyTorch
`PyTorch` implements `fractional_max_pool2d / fractional_max_pool3d` in C++ and exposes them through a Python API.

Correspondingly, `FractionalMaxPool2d` is built on `fractional_max_pool2d`, and `FractionalMaxPool3d` on `fractional_max_pool3d`.

Documentation:
- [FRACTIONALMAXPOOL2D](https://pytorch.org/docs/stable/generated/torch.nn.FractionalMaxPool2d.html#fractionalmaxpool2d)
- [FRACTIONALMAXPOOL3D](https://pytorch.org/docs/stable/generated/torch.nn.FractionalMaxPool3d.html#fractionalmaxpool3d)
- [TORCH.NN.FUNCTIONAL.FRACTIONAL_MAX_POOL2D](https://pytorch.org/docs/stable/generated/torch.nn.functional.fractional_max_pool2d.html#torch.nn.functional.fractional_max_pool2d)
- [TORCH.NN.FUNCTIONAL.FRACTIONAL_MAX_POOL3D](https://pytorch.org/docs/stable/generated/torch.nn.functional.fractional_max_pool3d.html#torch.nn.functional.fractional_max_pool3d)
The corresponding interfaces are:

- `torch.nn.FractionalMaxPool2d(kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None)`
  - Description
    > Applies 2D fractional max pooling over an input signal composed of several input planes.
  - Parameters
    > kernel_size – the size of the window to take a max over.
    > output_size – the target output size
    > output_ratio – If one wants to have an output size as a ratio of the input size, this option can be given. This has to be a number or tuple in the range (0, 1)
    > return_indices – if True, will return the indices along with the outputs.
  - Returns
    > output (Tensor)
- `torch.nn.FractionalMaxPool3d(kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None)`
  - Description
    > Applies 3D fractional max pooling over an input signal composed of several input planes.
  - Parameters
    > kernel_size – the size of the window to take a max over.
    > output_size – the target output size
    > output_ratio – If one wants to have an output size as a ratio of the input size, this option can be given. This has to be a number or tuple in the range (0, 1)
    > return_indices – if True, will return the indices along with the outputs.
  - Returns
    > output (Tensor)
- `torch.nn.functional.fractional_max_pool2d(input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None)`
  - Description
    > Applies 2D fractional max pooling over an input signal composed of several input planes.
  - Parameters
    > kernel_size – the size of the window to take a max over.
    > output_size – the target output size
    > output_ratio – If one wants to have an output size as a ratio of the input size, this option can be given. This has to be a number or tuple in the range (0, 1)
    > return_indices – if True, will return the indices along with the outputs.
  - Returns
    > output (Tensor)
- `torch.nn.functional.fractional_max_pool3d(input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None)`
  - Description
    > Applies 3D fractional max pooling over an input signal composed of several input planes.
  - Parameters
    > kernel_size – the size of the window to take a max over.
    > output_size – the target output size
    > output_ratio – If one wants to have an output size as a ratio of the input size, this option can be given. This has to be a number or tuple in the range (0, 1)
    > return_indices – if True, will return the indices along with the outputs.
  - Returns
    > output (Tensor)
Implementation:

`fractional_max_pool2d` and `fractional_max_pool3d` differ mainly in dimensionality and otherwise share the same logic, so the analysis below focuses on `fractional_max_pool2d`.

Relevant source files:
- `torch/nn/functional.py` *
- `torch/csrc/api/include/torch/nn/options/pooling.h`
- `torch/csrc/api/include/torch/nn/functional/pooling.h` *
- `torch/csrc/api/include/torch/nn/modules/pooling.h`
- `torch/csrc/api/src/nn/modules/pooling.cpp`
- `aten/src/ATen/native/FractionalMaxPooling.h` *
- `aten/src/ATen/native/FractionalMaxPool2d.cpp` *
Only the main files marked with `*` are analyzed here.

- `torch/nn/functional.py`

  This file exposes the public `fractional_max_pool2d` API:
``` python
def fractional_max_pool2d_with_indices(
    input: Tensor, kernel_size: BroadcastingList2[int],
    output_size: Optional[BroadcastingList2[int]] = None,
    output_ratio: Optional[BroadcastingList2[float]] = None,
    return_indices: bool = False,
    _random_samples: Optional[Tensor] = None
) -> Tuple[Tensor, Tensor]:
    r"""
    fractional_max_pool2d(input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None)
    ...
    """
    if has_torch_function_variadic(input, _random_samples):
        return handle_torch_function(
            fractional_max_pool2d_with_indices,
            (input, _random_samples),
            input,
            kernel_size,
            output_size=output_size,
            output_ratio=output_ratio,
            return_indices=return_indices,
            _random_samples=_random_samples,
        )
    if output_size is None and output_ratio is None:
        raise ValueError("fractional_max_pool2d requires specifying either " "an output_size or an output_ratio")
    if output_size is None:
        assert output_ratio is not None
        if len(output_ratio) > 2:
            raise ValueError("fractional_max_pool2d requires output_ratio to either be a single Int or tuple of Ints.")
        _output_ratio = _pair(output_ratio)
        output_size = [int(input.size(-2) * _output_ratio[0]), int(input.size(-1) * _output_ratio[1])]

    if _random_samples is None:
        n_batch = 1 if input.dim() == 3 else input.size(0)
        _random_samples = torch.rand(n_batch, input.size(-3), 2, dtype=input.dtype, device=input.device)
    return torch._C._nn.fractional_max_pool2d(input, kernel_size, output_size, _random_samples)


def _fractional_max_pool2d(
    input: Tensor, kernel_size: BroadcastingList2[int],
    output_size: Optional[BroadcastingList2[int]] = None,
    output_ratio: Optional[BroadcastingList2[float]] = None,
    return_indices: bool = False,
    _random_samples: Optional[Tensor] = None
) -> Tensor:
    if has_torch_function_variadic(input, _random_samples):
        return handle_torch_function(
            fractional_max_pool2d,
            (input, _random_samples),
            input,
            kernel_size,
            output_size=output_size,
            output_ratio=output_ratio,
            return_indices=return_indices,
            _random_samples=_random_samples,
        )
    return fractional_max_pool2d_with_indices(
        input, kernel_size, output_size, output_ratio, return_indices, _random_samples
    )[0]


fractional_max_pool2d = boolean_dispatch(
    arg_name="return_indices",
    arg_index=4,
    default=False,
    if_true=fractional_max_pool2d_with_indices,
    if_false=_fractional_max_pool2d,
    module_name=__name__,
    func_name="fractional_max_pool2d",
)
```
The API is dispatched on whether `indices` are requested. Either way it ends up calling `fractional_max_pool2d_with_indices`.
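The Python-side preprocessing above comes down to two small rules, restated here without torch as a sketch (the helper names are hypothetical):

- An `output_ratio` becomes an `output_size` by truncating `size * ratio`.
- When `_random_samples` is absent, one uniform sample is drawn per batch element, per channel and per spatial dimension. This gives a tensor of shape `(N, C, 2)`, with `N = 1` for unbatched 3-D input.

```python
from typing import Optional, Sequence, Tuple, Union

Pair = Union[int, float, Sequence]


def resolve_output_size(in_h: int, in_w: int,
                        output_size: Optional[Pair] = None,
                        output_ratio: Optional[Pair] = None) -> Tuple[int, int]:
    """Mirror of the output_size / output_ratio handling above (hypothetical helper)."""
    if output_size is None and output_ratio is None:
        raise ValueError("requires specifying either an output_size or an output_ratio")
    if output_size is None:
        rh, rw = (output_ratio, output_ratio) if isinstance(output_ratio, (int, float)) else output_ratio
        # int() truncates, like int(input.size(-2) * _output_ratio[0]) above
        return int(in_h * rh), int(in_w * rw)
    oh, ow = (output_size, output_size) if isinstance(output_size, int) else output_size
    return int(oh), int(ow)


def random_samples_shape(input_shape: Sequence[int]) -> Tuple[int, int, int]:
    """Shape of the default _random_samples: one value per (batch, channel, H/W)."""
    n_batch = 1 if len(input_shape) == 3 else input_shape[0]
    return (n_batch, input_shape[-3], 2)


print(resolve_output_size(25, 30, output_ratio=0.5))  # (12, 15)
print(random_samples_shape((8, 3, 25, 30)))           # (8, 3, 2)
```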
- `torch/csrc/api/include/torch/nn/functional/pooling.h`
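  This is the C++ frontend (libtorch) counterpart of the Python function above. It mirrors the same preprocessing and, like the Python version, finally calls the ATen operator `fractional_max_pool2d`: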
``` cpp
namespace detail {
inline std::tuple<Tensor, Tensor> fractional_max_pool2d_with_indices(
    const Tensor& input,
    const ExpandingArray<2>& kernel_size,
    const c10::optional<ExpandingArray<2>>& output_size,
    const c10::optional<ExpandingArray<2, double>>& output_ratio,
    const Tensor& _random_samples) {
  if (output_size == c10::nullopt && output_ratio == c10::nullopt) {
    TORCH_CHECK(
        false,
        "fractional_max_pool2d requires specifying either ",
        "an output_size or an output_ratio");
  }
  c10::optional<ExpandingArray<2>> output_size_ = output_size;
  if (output_size_ == c10::nullopt) {
    TORCH_INTERNAL_ASSERT(output_ratio != c10::nullopt);
    output_size_ = {
        (int64_t)(static_cast<double>(input.size(-2)) * (*output_ratio.value())[0]),
        (int64_t)(static_cast<double>(input.size(-1)) * (*output_ratio.value())[1])};
  }

  Tensor _random_samples_ = _random_samples;
  if (!_random_samples_.defined()) {
    auto n_batch = input.dim() == 3 ? 1 : input.size(0);
    _random_samples_ = torch::rand(
        {n_batch, input.size(-3), 2},
        torch::TensorOptions().dtype(input.dtype()).device(input.device()));
  }
  return torch::fractional_max_pool2d(
      input, kernel_size, *output_size_, _random_samples_);
}
} // namespace detail
```
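This is the entry point of `fractional_max_pool2d` in the C++ frontend. It does the following:
- if `output_size` is not given, computes it from `output_ratio`
- if `_random_samples` is not given, generates random samples according to the input dimensions
- calls the main function `torch::fractional_max_pool2d(input, kernel_size, *output_size_, _random_samples_);`
- `aten/src/ATen/native/FractionalMaxPool2d.cpp`

  This file implements the actual logic: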
``` cpp
template <typename scalar_t>
static void fractional_max_pool2d_out_single_batch_frame(
    scalar_t* input,
    scalar_t* output,
    int64_t* indices,
    scalar_t* randomSamples,
    int numPlanes,
    int inputW, int inputH,
    int outputW, int outputH,
    int poolSizeW, int poolSizeH) {
  at::parallel_for(0, numPlanes, 0, [&](int64_t start, int64_t end) {
    for (const auto plane : c10::irange(start, end)) {
      /* each plane contains 2 random samples, one for W and one for H */
      scalar_t* randomSamplesForPlane = randomSamples + plane * 2;

      /* Generate interval sequence */
      auto sequenceW = generate_intervals<scalar_t>(
          randomSamplesForPlane[0], inputW, outputW, poolSizeW);
      auto sequenceH = generate_intervals<scalar_t>(
          randomSamplesForPlane[1], inputH, outputH, poolSizeH);

      /* loop over output */
      // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
      int h, w;

      scalar_t* inputForPlane = input + plane * inputW * inputH;
      scalar_t* outputForPlane = output + plane * outputW * outputH;
      int64_t* indicesForPlane = indices + plane * outputW * outputH;

      for (h = 0; h < outputH; ++h) {
        int inputHStart = sequenceH[h];

        for (w = 0; w < outputW; ++w) {
          int inputWStart = sequenceW[w];

          int h2 = inputHStart, w2 = inputWStart;
          scalar_t maxVal = -std::numeric_limits<scalar_t>::infinity();
          int64_t maxIndex = h2 * inputW + w2;

          for (h2 = inputHStart; h2 < inputHStart + poolSizeH; ++h2) {
            for (w2 = inputWStart; w2 < inputWStart + poolSizeW; ++w2) {
              AT_ASSERT(h2 >= 0 && h2 < inputH);
              AT_ASSERT(w2 >= 0 && w2 < inputW);

              int planeIndex = h2 * inputW + w2;
              scalar_t val = inputForPlane[planeIndex];
              if (val > maxVal || std::isnan(val)) {
                maxVal = val;
                maxIndex = planeIndex;
              }
            }
          }

          outputForPlane[h * outputW + w] = maxVal;
          indicesForPlane[h * outputW + w] = maxIndex;
        }
      }
    }
  });
}
```
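This file implements the main logic of `fractional_max_pool2d`; only the most important code is excerpted above.

The main steps are:
- generate the sampling sequences
- take the maximum within each pooling window defined by the sequences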
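The per-plane loop above can be restated in pure Python to show what each output element is: the max over a `poolSizeH x poolSizeW` window anchored at the sampled start positions. This sketch is ours (the helper name is hypothetical). It handles one plane as nested lists instead of a tensor, takes the start sequences as given, and keeps the C++ behaviour of reporting flat indices `h * inputW + w` and letting a NaN win.

```python
import math


def max_pool_windows_2d(plane, seq_h, seq_w, pool_h, pool_w):
    """Max over each pool_h x pool_w window whose top-left corner is (seq_h[i], seq_w[j]).

    Returns (output, indices); each index is the flat position h * input_w + w,
    and a NaN in a window wins, as in the C++ loop above.
    """
    input_h, input_w = len(plane), len(plane[0])
    output, indices = [], []
    for h0 in seq_h:
        out_row, idx_row = [], []
        for w0 in seq_w:
            max_val, max_idx = -math.inf, h0 * input_w + w0
            for h in range(h0, h0 + pool_h):
                for w in range(w0, w0 + pool_w):
                    assert 0 <= h < input_h and 0 <= w < input_w
                    val = plane[h][w]
                    if val > max_val or math.isnan(val):
                        max_val, max_idx = val, h * input_w + w
            out_row.append(max_val)
            idx_row.append(max_idx)
        output.append(out_row)
        indices.append(idx_row)
    return output, indices


plane = [[4 * r + c for c in range(4)] for r in range(4)]  # values equal their flat index
print(max_pool_windows_2d(plane, [0, 2], [0, 1, 2], 2, 2)[0])  # [[5, 6, 7], [13, 14, 15]]
```

With starts `[0, 1, 2]` along W and a window width of 2, neighbouring windows share a column. A fixed kernel with sampled starts therefore produces overlapping pooling regions.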
The sampling sequence is generated in `aten/src/ATen/native/FractionalMaxPooling.h`:
``` cpp
template<typename scalar_t>
static inline std::vector<int> generate_intervals(
    scalar_t sample,
    int64_t inputSize,
    int64_t outputSize,
    int64_t poolSize) {
  std::vector<int> sequence(outputSize);
  if (outputSize > 1) {
    scalar_t alpha = static_cast<scalar_t>(inputSize - poolSize) /
        static_cast<scalar_t>(outputSize - 1);

    for (const auto i : c10::irange(outputSize - 1)) {
      sequence[i] =
          static_cast<int>((i + sample) * alpha) - static_cast<int>(sample * alpha);
    }
  }
  if (outputSize > 0) {
    sequence[outputSize - 1] = inputSize - poolSize;
  }
  return sequence;
}
```
The source analysis above shows that `PyTorch` implements only the `pseudo random` approach for `fractional_max_pool`, not the `random` one.
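To check the arithmetic, here is a direct Python transcription of `generate_intervals` (ours, not PyTorch code). Each entry is a window *start*, and the window length is always `poolSize`. With `kernel_size = 2` and `1 < alpha < 2`, neighbouring windows may therefore overlap.

```python
def generate_intervals(sample: float, input_size: int, output_size: int, pool_size: int) -> list[int]:
    """Python transcription of PyTorch's generate_intervals: the start index of each window."""
    sequence = [0] * output_size
    if output_size > 1:
        alpha = (input_size - pool_size) / (output_size - 1)
        for i in range(output_size - 1):
            # int() truncates toward zero, matching static_cast<int> for these non-negative values
            sequence[i] = int((i + sample) * alpha) - int(sample * alpha)
    if output_size > 0:
        sequence[output_size - 1] = input_size - pool_size
    return sequence


starts = generate_intervals(0.5, 25, 18, 2)         # sample = 0.5, kernel_size = 2
print(starts[0], starts[-1])                        # 0 23
print([b - a for a, b in zip(starts, starts[1:])])  # strides, each 1 or 2 here
```

The first start is always `0` and the last is always `input_size - pool_size`, so every window stays inside the input regardless of the sampled value.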
## TensorFlow
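`TensorFlow` implements `tf.nn.fractional_max_pool`, which corresponds to `PyTorch`'s `fractional_max_pool2d`.

It also exposes the underlying op `tf.raw_ops.FractionalMaxPool`, which `tf.nn.fractional_max_pool` wraps through `gen_nn_ops.fractional_max_pool`.

`TensorFlow` has no `3D` implementation. Compared with `2D`, `3D` adds one more dimension, such as `depth` or `time`.

Documentation: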
- [tf.raw_ops.FractionalMaxPool](https://tensorflow.google.cn/api_docs/python/tf/raw_ops/FractionalMaxPool?hl=en)
- [tf.nn.fractional_max_pool](https://tensorflow.google.cn/api_docs/python/tf/nn/fractional_max_pool?hl=en)
The corresponding interfaces are:

- `tf.raw_ops.FractionalMaxPool`
  - Description
    > Performs fractional max pooling on the input.
  - Parameters
    > value – A Tensor. 4-D with shape [batch, height, width, channels].
    > pooling_ratio – An int or list of ints that has length 1, 2 or 4.
    > pseudo_random – An optional bool. Defaults to False. When set to True, generates the pooling sequence in a pseudorandom fashion, otherwise, in a random fashion.
    > overlapping – An optional bool. Defaults to False. When set to True, it means when pooling, the values at the boundary of adjacent pooling cells are used by both cells.
    > deterministic – An optional bool. Defaults to False. When set to True, a fixed pooling region will be used when iterating over a FractionalMaxPool node in the computation graph.
    > seed – An optional int. Defaults to 0. If set to be non-zero, the random number generator is seeded by the given seed. Otherwise it is seeded by a random seed.
    > seed2 – An optional int. Defaults to 0. An second seed to avoid seed collision.
    > name – A name for the operation (optional).
  - Returns
    > output (A tuple of Tensor objects)
- `tf.nn.fractional_max_pool`
  - Description
    > Performs fractional max pooling on the input.
  - Parameters
    > value – A Tensor. 4-D with shape [batch, height, width, channels].
    > pooling_ratio – An int or list of ints that has length 1, 2 or 4.
    > pseudo_random – An optional bool. Defaults to False. When set to True, generates the pooling sequence in a pseudorandom fashion, otherwise, in a random fashion.
    > overlapping – An optional bool. Defaults to False. When set to True, it means when pooling, the values at the boundary of adjacent pooling cells are used by both cells.
    > seed – An optional int. Defaults to 0. If set to be non-zero, the random number generator is seeded by the given seed. Otherwise it is seeded by a random seed.
    > name – A name for the operation (optional).
  - Returns
    > output (A tuple of Tensor objects)
Implementation:

Relevant source files:
- `tensorflow/python/ops/nn_ops.py` *
- `tensorflow/core/kernels/fractional_pool_common.h`
- `tensorflow/core/kernels/fractional_pool_common.cc` *
- `tensorflow/core/kernels/fractional_max_pool_op.cc` *
Only the main files marked with `*` are analyzed here.

- `tensorflow/python/ops/nn_ops.py`

  This file registers the Python API:
``` python
@tf_export("nn.fractional_max_pool", v1=[])
@dispatch.add_dispatch_support
def fractional_max_pool_v2(value,
                           pooling_ratio,
                           pseudo_random=False,
                           overlapping=False,
                           seed=0,
                           name=None):  # pylint: disable=redefined-builtin
  if (isinstance(pooling_ratio, (list, tuple))):
    if (pooling_ratio[0] != 1.0 or pooling_ratio[-1] != 1.0):
      raise ValueError(
          "`pooling_ratio` should have first and last elements with value 1.0. "
          f"Received: pooling_ratio={pooling_ratio}")
    for element in pooling_ratio:
      if element < 1.0:
        raise ValueError(
            f"`pooling_ratio` elements should be >= 1.0. "
            f"Received: pooling_ratio={pooling_ratio}")
  elif (isinstance(pooling_ratio, (int, float))):
    if pooling_ratio < 1.0:
      raise ValueError(
          "`pooling_ratio` should be >= 1.0. "
          f"Received: pooling_ratio={pooling_ratio}")
  else:
    raise ValueError(
        "`pooling_ratio` should be an int or a list of ints. "
        f"Received: pooling_ratio={pooling_ratio}")

  pooling_ratio = _get_sequence(pooling_ratio, 2, 3, "pooling_ratio")

  if seed == 0:
    if config.is_op_determinism_enabled():
      raise ValueError(
          f"tf.nn.fractional_max_pool requires a non-zero seed to be passed in "
          f"when determinism is enabled, but got seed={seed}. Please pass in a "
          f'non-zero seed, e.g. by passing "seed=1".')
    return gen_nn_ops.fractional_max_pool(value, pooling_ratio, pseudo_random,
                                          overlapping, deterministic=False,
                                          seed=0, seed2=0, name=name)
  else:
    seed1, seed2 = random_seed.get_seed(seed)
    return gen_nn_ops.fractional_max_pool(value, pooling_ratio, pseudo_random,
                                          overlapping, deterministic=True,
                                          seed=seed1, seed2=seed2, name=name)
```
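Unlike `PyTorch`, `TensorFlow` takes a few extra parameters:
- `overlapping`: when `True`, the values at the boundary between adjacent pooling cells are used by both cells
- `pseudo_random`: whether to generate the pooling sequence pseudo-randomly instead of randomly
- `seed`: the random seed
- `tensorflow/core/kernels/fractional_max_pool_op.cc`

  This file implements the main logic: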
``` cpp
template <typename T>
class FractionalMaxPoolOp : public OpKernel {
 public:
  explicit FractionalMaxPoolOp(OpKernelConstruction* context)
      : OpKernel(context) {
    ...
    if (deterministic_) {
      // If both seeds are not set when deterministic_ is true, force set seeds.
      if ((seed_ == 0) && (seed2_ == 0)) {
        seed_ = random::New64();
        seed2_ = random::New64();
      }
    } else {
      OP_REQUIRES(
          context, (seed_ == 0) && (seed2_ == 0),
          errors::InvalidArgument(
              "Both seed and seed2 should be 0 if deterministic is false."));
    }
  }

  void Compute(OpKernelContext* context) override {
    typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
        ConstEigenMatrixMap;
    typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
        EigenMatrixMap;

    constexpr int tensor_in_and_out_dims = 4;

    const Tensor& tensor_in = context->input(0);

    std::vector<int> input_size(tensor_in_and_out_dims);
    std::vector<int> output_size(tensor_in_and_out_dims);
    for (int i = 0; i < tensor_in_and_out_dims; ++i) {
      input_size[i] = tensor_in.dim_size(i);
    }
    // Output size.
    for (int i = 0; i < tensor_in_and_out_dims; ++i) {
      // This must match the same logic in the shape function in
      // core/ops/nn_ops.cc.
      output_size[i] =
          static_cast<int>(std::floor(input_size[i] / pooling_ratio_[i]));
      DCHECK_GT(output_size[i], 0);
    }

    // Generate pooling sequence.
    std::vector<int64_t> height_cum_seq;
    std::vector<int64_t> width_cum_seq;
    GuardedPhiloxRandom generator;
    generator.Init(seed_, seed2_);
    height_cum_seq = GeneratePoolingSequence(input_size[1], output_size[1],
                                             &generator, pseudo_random_);
    width_cum_seq = GeneratePoolingSequence(input_size[2], output_size[2],
                                            &generator, pseudo_random_);

    // Prepare output.
    Tensor* output_tensor = nullptr;
    Tensor* output_height_seq_tensor = nullptr;
    Tensor* output_width_seq_tensor = nullptr;
    // ... (allocation of the three output tensors is omitted in this excerpt)

    ConstEigenMatrixMap in_mat(tensor_in.flat<T>().data(), input_size[3],
                               input_size[2] * input_size[1] * input_size[0]);

    EigenMatrixMap out_mat(output_tensor->flat<T>().data(), output_size[3],
                           output_size[2] * output_size[1] * output_size[0]);

    // Initializes the output tensor with MIN<T>.
    output_tensor->flat<T>().setConstant(Eigen::NumTraits<T>::lowest());

    auto output_height_seq_flat = output_height_seq_tensor->flat<int64_t>();
    auto output_width_seq_flat = output_width_seq_tensor->flat<int64_t>();

    // Set output tensors.
    for (int i = 0; i < height_cum_seq.size(); ++i) {
      output_height_seq_flat(i) = height_cum_seq[i];
    }

    for (int i = 0; i < width_cum_seq.size(); ++i) {
      output_width_seq_flat(i) = width_cum_seq[i];
    }

    // For both input and output,
    // 0: batch
    // 1: height / row
    // 2: width / col
    // 3: depth / channel
    const int64_t height_max = input_size[1] - 1;
    const int64_t width_max = input_size[2] - 1;
    for (int64_t b = 0; b < input_size[0]; ++b) {
      // height sequence.
      for (int64_t hs = 0; hs < height_cum_seq.size() - 1; ++hs) {
        // height start and end.
        const int64_t height_start = height_cum_seq[hs];
        int64_t height_end =
            overlapping_ ? height_cum_seq[hs + 1] : height_cum_seq[hs + 1] - 1;
        height_end = std::min(height_end, height_max);

        // width sequence.
        for (int64_t ws = 0; ws < width_cum_seq.size() - 1; ++ws) {
          const int64_t out_offset =
              (b * output_size[1] + hs) * output_size[2] + ws;
          // width start and end.
          const int64_t width_start = width_cum_seq[ws];
          int64_t width_end =
              overlapping_ ? width_cum_seq[ws + 1] : width_cum_seq[ws + 1] - 1;
          width_end = std::min(width_end, width_max);
          for (int64_t h = height_start; h <= height_end; ++h) {
            for (int64_t w = width_start; w <= width_end; ++w) {
              const int64_t in_offset =
                  (b * input_size[1] + h) * input_size[2] + w;
              out_mat.col(out_offset) =
                  out_mat.col(out_offset).cwiseMax(in_mat.col(in_offset));
            }
          }
        }
      }
    }
  }
};
```
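Without Eigen and the batch/channel layout, the core of `Compute` is two steps:

1. Compute `output_size = floor(input_size / pooling_ratio)`.
2. Take the max over each cell between consecutive cumulative boundaries, also including the right/bottom boundary element when `overlapping` is set.

Below is a single-channel pure-Python sketch. The helper names are hypothetical, and the boundaries are passed in directly rather than generated.

```python
import math


def tf_style_output_size(input_size: int, pooling_ratio: float) -> int:
    # Same rule as the kernel above: floor(input_size / pooling_ratio).
    return math.floor(input_size / pooling_ratio)


def pool_with_cum_seq(plane, row_cum_seq, col_cum_seq, overlapping=False):
    """Max over each cell between consecutive cumulative boundaries (one channel).

    Without overlapping, cell i covers indices cum[i] .. cum[i+1] - 1; with
    overlapping it also includes cum[i+1], clipped to the last valid index.
    """
    h_max, w_max = len(plane) - 1, len(plane[0]) - 1
    out = []
    for hs in range(len(row_cum_seq) - 1):
        h_end = row_cum_seq[hs + 1] if overlapping else row_cum_seq[hs + 1] - 1
        h_end = min(h_end, h_max)
        row = []
        for ws in range(len(col_cum_seq) - 1):
            w_end = col_cum_seq[ws + 1] if overlapping else col_cum_seq[ws + 1] - 1
            w_end = min(w_end, w_max)
            row.append(max(plane[h][w]
                           for h in range(row_cum_seq[hs], h_end + 1)
                           for w in range(col_cum_seq[ws], w_end + 1)))
        out.append(row)
    return out


plane = [[5 * r + c for c in range(5)] for r in range(5)]
seq = [0, 2, 3, 5]  # increments 2, 1, 2 -> 3 output cells, matching floor(5 / 1.6)
print(tf_style_output_size(5, 1.6))                          # 3
print(pool_with_cum_seq(plane, seq, seq))                    # [[6, 7, 9], [11, 12, 14], [21, 22, 24]]
print(pool_with_cum_seq(plane, seq, seq, overlapping=True))  # [[12, 13, 14], [17, 18, 19], [22, 23, 24]]
```

Here the sequence determines both the stride and the window extent. This contrasts with the PyTorch kernel, where the window length is fixed by `kernel_size`.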
The sequence-generation functions are in `tensorflow/core/kernels/fractional_pool_common.cc`:
``` cpp
static std::vector<int64_t> GeneratePoolingSequencePseudoRandom(
    int input_length, int output_length, GuardedPhiloxRandom* generator) {
  std::vector<int64_t> cum_seq(output_length + 1, 0);
  std::vector<int64_t> diff(output_length, 0);

  double alpha = static_cast<double>(input_length) / output_length;
  int k = input_length / output_length;

  double u_max1 = (k + 2) / alpha - 1;
  double u_max2 = (input_length + 1 - k) / alpha - (output_length - 1);
  double max_u = std::min(u_max1, u_max2);

  // Generate random number in parallel.
  auto local_gen = generator->ReserveSamples32(2);
  random::SimplePhilox random(&local_gen);
  const double u = random.RandDouble() * max_u;

  cum_seq[0] = 1;
  cum_seq[output_length] = input_length + 1;
  for (int i = 1; i < output_length; ++i) {
    cum_seq[i] = static_cast<int>(ceil(alpha * (i + u)));
  }

  for (int i = 0; i < output_length; ++i) {
    diff[i] = cum_seq[i + 1] - cum_seq[i];
  }

  return diff;
}

static std::vector<int64_t> GeneratePoolingSequenceRandom(
    int input_length, int output_length, GuardedPhiloxRandom* generator) {
  int k = input_length / output_length;
  int num_random_spot = input_length % output_length;
  std::vector<int64_t> diff(output_length, k);

  for (int i = 0; i < num_random_spot; ++i) {
    diff[i] += 1;
  }

  // Randomly shuffle this vector.
  auto local_gen = generator->ReserveSamples32(diff.size());
  random::SingleSampleAdapter<random::PhiloxRandom> single(&local_gen);
  const auto uniform = [&single](uint32 n) { return single() % n; };
  RandomShuffle(diff.begin(), diff.end(), uniform);

  return diff;
}

std::vector<int64_t> GeneratePoolingSequence(int input_length,
                                             int output_length,
                                             GuardedPhiloxRandom* generator,
                                             bool pseudo_random) {
  std::vector<int64_t> diff;
  // This is a case that regular pooling can handle, just return diff with
  // each element input_length/output_length.
  if (input_length % output_length == 0) {
    diff = std::vector<int64_t>(output_length, input_length / output_length);
  }

  if (pseudo_random) {
    diff = GeneratePoolingSequencePseudoRandom(input_length, output_length,
                                               generator);
  } else {
    diff =
        GeneratePoolingSequenceRandom(input_length, output_length, generator);
  }

  // Sanity check.
  int k = input_length / output_length;
  for (int i = 0; i < output_length; ++i) {
    // k<= diff[i] <= k+1.
    DCHECK_GE(diff[i], k);
    DCHECK_LE(diff[i], k + 1);
  }

  // Return cumulative sequence.
  std::vector<int64_t> cum_seq(output_length + 1, 0);
  for (int i = 1; i < cum_seq.size(); ++i) {
    cum_seq[i] = cum_seq[i - 1] + diff[i - 1];
  }
  return cum_seq;
}
```
Depending on the `pseudo_random` flag, this generates either a *pseudo-random* or a *truly random* sequence.
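Below is a Python transcription of the pseudo-random path, followed by the cumulative sum (ours, not TensorFlow code). `random.Random` stands in for the Philox generator, so the concrete sequences will not match TensorFlow's; only the structure is the same. Note how `max_u` bounds `u` so that every increment stays within `[k, k + 1]`, which is what the `DCHECK`s above assert.

```python
import math
import random


def pseudo_random_cum_seq(input_length: int, output_length: int, rng: random.Random) -> list[int]:
    """Port of GeneratePoolingSequencePseudoRandom followed by the cumulative sum.

    Returns output_length + 1 boundaries starting at 0 and ending at input_length.
    """
    if not 0 < output_length <= input_length:
        raise ValueError("need 0 < output_length <= input_length")
    alpha = input_length / output_length
    k = input_length // output_length
    # Upper bound on u so that the first and last increments stay within [k, k + 1].
    max_u = min((k + 2) / alpha - 1, (input_length + 1 - k) / alpha - (output_length - 1))
    u = rng.random() * max_u
    # 1-based boundaries, as in the C++ code: cum[0] = 1, cum[output_length] = input_length + 1.
    cum = [1] + [math.ceil(alpha * (i + u)) for i in range(1, output_length)] + [input_length + 1]
    diff = [b - a for a, b in zip(cum, cum[1:])]
    boundaries = [0]
    for d in diff:
        boundaries.append(boundaries[-1] + d)
    return boundaries


print(pseudo_random_cum_seq(25, 18, random.Random(2024)))  # 19 boundaries from 0 to 25
```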
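# IV. Comparative Analysis

Setting aside the different ways `PyTorch` and `TensorFlow` organize their APIs, the two compare as follows.

Similarities:
- Both `PyTorch` and `TensorFlow` implement a `fractional_max_pool` function.
- Both implement the logic in C++ and expose the API through Python.

Differences:
- `PyTorch` provides both `2D` and `3D` functions, while `TensorFlow` only provides `2D` (not counting the `channel` dimension).
- `TensorFlow` supports both *random* and *pseudo-random* sequence generation, while `PyTorch` only supports *pseudo-random*.
- `TensorFlow`'s implementation is closer to the paper's description.

  This is the biggest difference between `PyTorch` and `TensorFlow`. In the paper, the fractional factor is determined by `N_in/N_out`, so these two quantities are all that is needed.

  `PyTorch` takes three parameters, `kernel_size`, `output_size` and `output_ratio`, all of which affect `N_in/N_out`; this is closer to traditional pooling.

  `TensorFlow` only takes `pooling_ratio`, from which `N_out` is obtained. Its `overlapping` flag effectively changes the kernel size. It follows that `TensorFlow`'s `fractional max pooling` is more general, and `adaptive max pooling` can be regarded as a special case of it.

  `PyTorch` uses the random sequence only as the stride and pools with a fixed kernel. `TensorFlow` uses the sequence as both stride and kernel, which matches the paper more closely, so this design follows the `TensorFlow` approach.

  - `fractional max pooling` : `a = ceiling(alpha(i+u)), 1 < alpha = N_in/N_out < 2, 0 < u < 1`
  - `adaptive max pooling` : `a = ceiling(alpha(i+1)), 1 < alpha = N_in/N_out < 2`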
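The two formulas differ only in `u`. Fixing `u = 1` turns the fractional boundaries into the adaptive ones, which is why adaptive pooling can be viewed as a deterministic special case. Below is a small sketch using exact rational arithmetic. The helper name is ours, and it only computes the boundary list, not the pooling.

```python
import math
from fractions import Fraction


def boundaries(n_in: int, n_out: int, u: Fraction) -> list[int]:
    """a_i = ceil(alpha * (i + u)) for i = 0 .. n_out - 1, with alpha = n_in / n_out."""
    alpha = Fraction(n_in, n_out)  # exact arithmetic: no float rounding right at integers
    return [math.ceil(alpha * (i + u)) for i in range(n_out)]


fractional = boundaries(25, 18, Fraction(1, 3))  # one choice of u in (0, 1)
adaptive = boundaries(25, 18, Fraction(1))       # u = 1: the adaptive formula
print(adaptive[-1])                              # 25, i.e. the last window ends at N_in
print([b - a for a, b in zip(adaptive, adaptive[1:])])  # every increment is 1 or 2
```

`Fraction` is used instead of `float` because `25/18 * 18` may not come out as exactly `25` in floating point. In that case `ceil` would overshoot by one.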
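Both frameworks also implement the backward (gradient) computation. It is omitted from the code analysis above because it does not affect the main logic and involves a lot of code.

Paddle already implements `AdaptiveMaxPool1D / AdaptiveMaxPool2D / AdaptiveMaxPool3D`, with signatures such as:
- `paddle.nn.AdaptiveMaxPool1D(output_size, return_mask=False, name=None)`

For consistency, this design likewise takes `output_size` as its only required parameter, an approach closer to the paper and to `TensorFlow`.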
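# V. Design and Implementation

This proposal consists of three parts:
- Naming and parameter design (Python API): `paddle.nn.functional.fractional_max_pool2d`, `paddle.nn.functional.fractional_max_pool3d`
- Underlying OP design
- Python layer implementation: `paddle.nn.FractionalMaxPool2d`, `paddle.nn.FractionalMaxPool3d`

The interfaces of `fractional max pooling` and `adaptive max pooling` are quite similar. The design below therefore follows the implementation of the `adaptive max pooling` operator, while implementing `fractional max pooling` as a separate operator.

## Naming and Parameter Design (Python API)

Files involved: `python/paddle/nn/functional/pooling.py`

New Python APIs: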
- `paddle.nn.functional.fractional_max_pool2d`
- `paddle.nn.FractionalMaxPool2d`
``` python
paddle.nn.functional.fractional_max_pool2d(
    x: Tensor,
    output_size: Union[int, list, tuple],
    kernel_size: Optional[Union[int, list, tuple]] = None,
    random_u: Optional[float] = None,
    return_mask: bool = False,
    name: Optional[str] = None)
```
- Parameters
  > x (Tensor) – the input Tensor. Supported data types: float32, float64, int32, int64.
  > output_size (int|list|tuple) – size of the output.
  > kernel_size (int|list|tuple, optional) – kernel size.
  > random_u (float, optional) – the random number used to generate the pooling sequence.
  > return_mask (bool, optional) – whether to also return the indices of the maxima.
  > name (str, optional) – name of the operation.
- Returns
  > Tensor, if return_mask=False
  > Tensor and mask, if return_mask=True
- `paddle.nn.functional.fractional_max_pool3d`
- `paddle.nn.FractionalMaxPool3d`
``` python
paddle.nn.functional.fractional_max_pool3d(
    x: Tensor,
    output_size: Union[int, list, tuple],
    kernel_size: Optional[Union[int, list, tuple]] = None,
    random_u: Optional[float] = None,
    return_mask: bool = False,
    name: Optional[str] = None)
```
- Parameters
  > x (Tensor) – the input Tensor. Supported data types: float32, float64, int32, int64.
  > output_size (int|list|tuple) – size of the output.
  > kernel_size (int|list|tuple, optional) – kernel size.
  > random_u (float, optional) – the random number used to generate the pooling sequence.
  > return_mask (bool, optional) – whether to also return the indices of the maxima.
  > name (str, optional) – name of the operation.
- Returns
  > Tensor, if return_mask=False
  > Tensor and mask, if return_mask=True
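To make the parameter semantics concrete, here is a hypothetical pure-Python normalization step that a `fractional_max_pool2d/3d` wrapper could perform before calling the operator. It is a sketch, not Paddle code. Several behaviours are our assumptions:

- Integer arguments are broadcast to every spatial dimension.
- `None` maps to the yaml defaults given for the operator (`kernel_size = {0, 0}`, `random_u = 0.0`).
- An explicit `random_u` must satisfy the `0 < u < 1` condition of the pseudo-random formula.

```python
from typing import List, Optional, Sequence, Tuple, Union

IntOrSeq = Union[int, Sequence[int]]


def normalize_pool_args(output_size: IntOrSeq,
                        kernel_size: Optional[IntOrSeq],
                        random_u: Optional[float],
                        ndim: int) -> Tuple[List[int], List[int], float]:
    """Hypothetical argument handling for fractional_max_pool2d (ndim=2) / 3d (ndim=3).

    Assumptions, not Paddle code: ints are broadcast to every spatial dimension,
    kernel_size=None maps to all zeros (the yaml default, i.e. disjoint mode),
    and random_u=None maps to 0.0 (the yaml default).
    """
    def to_list(value: IntOrSeq) -> List[int]:
        values = [value] * ndim if isinstance(value, int) else list(value)
        if len(values) != ndim:
            raise ValueError(f"expected {ndim} values, got {values}")
        return values

    if random_u is not None and not 0.0 < random_u < 1.0:
        raise ValueError("random_u must lie in (0, 1)")
    kernel = [0] * ndim if kernel_size is None else to_list(kernel_size)
    return to_list(output_size), kernel, 0.0 if random_u is None else float(random_u)


print(normalize_pool_args(3, None, None, ndim=2))  # ([3, 3], [0, 0], 0.0)
```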
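The discussion below focuses on the naming and parameters of `paddle.nn.functional.fractional_max_poolNd`; `paddle.nn.FractionalMaxPoolNd` is analogous.

*Note*: compared with the v1.0 design document, many parameters have been removed, as follows:

- No `data_format`

  Consider the main pooling source file `python/paddle/nn/functional/pooling.py`, taking `max_pool2d` as an example:

  - two underlying operators are involved: `max_pool2d_with_index` and `pool2d`
  - `max_pool2d_with_index` can return `mask`; `pool2d` cannot
  - `max_pool2d_with_index` does not support `data_format`; `pool2d` does

  Hence, when `return_mask` is used to return `mask`, `data_format must be set to NCHW`.

  No single operator fully supports both parameters. This is a significant conflict in the current underlying pooling operators.

  Since this design is modeled on the underlying `adaptive max pooling` operator, it follows the `adaptive max pooling` interface:

  `adaptive_max_pool2d(x, output_size, return_mask=False, name=None)`

  and does not take a `data_format` parameter.

- `pseudo_random`, `overlapping` and `seed` are removed

  Following `PyTorch`, the pooling sequence is generated only in the *pseudo-random* way, inside the C++ operator.

*Note*: compared with the v2.0 design document, several parameters have been added, as follows:

- `kernel_size`

  Defaults to `None`, which selects the `disjoint (non-overlapping)` mode.

  When it is not `None`, the `overlapping` mode is used, consistent with the PyTorch implementation. This follows the explanation from Benjamin Graham, the author of Fractional Max-Pooling:

  > Hello. My original implementation (for sparse ConvNets) generated regions using this code:<https://github.com/btgraham/SparseConvNet-archived/blob/bdde325c28f64b895cebfdbe301a2ddca7870174/SparseConvNet/Regions.cu#L31>

  The implementation is kept consistent with the code provided by the author.

- `random_u`

  Adds the random number used to generate the pooling sequence, so that results are easier to reproduce.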
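The two `kernel_size` modes can be pictured on a single axis. The sketch below is an illustration under our own assumptions, not Paddle's implementation:

- In the disjoint mode, the windows are the cells between consecutive boundaries, so they tile the input.
- In the overlapping mode, each window has the fixed length `kernel_size` and starts at a sampled position, as in the PyTorch kernel analyzed earlier.

The sketch does not specify how Paddle derives the start positions in the overlapping mode, and the helper names are hypothetical.

```python
def disjoint_windows(boundaries: list[int]) -> list[list[int]]:
    # Disjoint mode: window i covers [boundaries[i], boundaries[i + 1]).
    return [list(range(a, b)) for a, b in zip(boundaries, boundaries[1:])]


def overlapping_windows(starts: list[int], kernel_size: int, input_size: int) -> list[list[int]]:
    # Overlapping mode (assumed): fixed-length windows [s, s + kernel_size) at each start.
    if any(s < 0 or s + kernel_size > input_size for s in starts):
        raise ValueError("every window must lie inside the input")
    return [list(range(s, s + kernel_size)) for s in starts]


boundaries = [0, 1, 3, 4, 6]  # increments 1, 2, 1, 2 over an input of length 6
print(disjoint_windows(boundaries))                # [[0], [1, 2], [3], [4, 5]]
print(overlapping_windows(boundaries[:-1], 2, 6))  # [[0, 1], [1, 2], [3, 4], [4, 5]]
```

With `kernel_size = 2`, elements 1 and 4 each fall into two windows, which is what makes the second mode overlapping.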
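## Underlying OP Design

> *Note*: the actual code is authoritative for the implementation details below.

Files involved:
- `paddle/phi/api/yaml/ops.yaml`: operator description and definition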
``` yaml
- op : fractional_max_pool2d
  args : (Tensor x, int[] output_size, int[] kernel_size = {0, 0}, float random_u = 0.0, bool return_mask = true)
  output : Tensor(out), Tensor(mask)
  infer_meta :
    func : FractionalMaxPoolInferMeta
  kernel :
    func : fractional_max_pool2d
  backward : fractional_max_pool2d_grad
```
Adds a `bool` parameter `fractional`, defaulting to `false`.
- `paddle/phi/api/yaml/backward.yaml`: backward operator description and definition
``` yaml
- backward_op : fractional_max_pool2d_grad
  forward : fractional_max_pool2d(Tensor x, int[] output_size, int[] kernel_size = {0, 0}, float random_u = 0.0, bool return_mask = true) -> Tensor(out), Tensor(mask)
  args : (Tensor x, Tensor mask, Tensor out_grad, int[] output_size, int[] kernel_size, float random_u, bool return_mask)
  output : Tensor(x_grad)
  infer_meta :
    func : UnchangedInferMeta
    param : [x]
  kernel :
    func : fractional_max_pool2d_grad
```
Adds a `bool` parameter `fractional`, defaulting to `false`.
- `paddle/phi/infermeta/unary.h`: operator InferMeta
``` cpp
void FractionalMaxPoolInferMeta(const MetaTensor& x,
                                const std::vector<int>& output_size,
                                const std::vector<int>& kernel_size,
                                float random_u,
                                bool return_mask,
                                MetaTensor* out,
                                MetaTensor* mask,
                                MetaConfig config = MetaConfig());
```
- `paddle/phi/kernels/pool_kernel.h`: operator kernel
``` cpp
template <typename T, typename Context>
void FractionalMaxPool2dKernel(const Context& ctx,
                               const DenseTensor& x,
                               const std::vector<int>& output_size,
                               const std::vector<int>& kernel_size,
                               float random_u,
                               bool return_mask,
                               DenseTensor* out,
                               DenseTensor* mask);
```
- `paddle/phi/kernels/funcs/pooling.h`
``` cpp
template <typename Context, typename T1, typename T2>
class FractionalMaxPool2dFunctor {
 public:
  void operator()(const Context& context,
                  const DenseTensor& input,
                  const std::vector<int>& output_size,
                  const std::vector<int>& kernel_size,
                  float random_u,
                  bool return_mask,
                  DenseTensor* output,
                  DenseTensor* mask);
};

template <typename Context, typename T1, typename T2>
class FractionalMaxPool2dGradFunctor {
 public:
  void operator()(const Context& context,
                  const DenseTensor& output_grad,
                  const DenseTensor& mask,
                  const std::vector<int>& output_size,
                  const std::vector<int>& kernel_size,
                  float random_u,
                  bool return_mask,
                  DenseTensor* input_grad);
};
```
``` cpp
HOSTDEVICE inline float FractionalRationalU(/* ... */);
HOSTDEVICE inline int FractionalStartIndex(/* ... */);
HOSTDEVICE inline int FractionalEndIndex(/* ... */);
```
Helper functions that generate the pooling sequence.
- `paddle/phi/kernels/impl/pool_kernel_impl.h`
``` cpp
void FractionalMaxPoolRawKernel(const Context& ctx,
                                const DenseTensor& x,
                                const std::vector<int>& output_size,
                                const std::vector<int>& kernel_size,
                                float random_u,
                                bool return_mask,
                                DenseTensor* out,
                                DenseTensor* mask);
```
- `paddle/phi/kernels/pool_grad_kernel.h`: backward kernel
``` cpp
template <typename T, typename Context>
void FractionalMaxPool2dGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& mask,
const DenseTensor& dout,
const std::vector<int>& output_size,
const std::vector<int>& kernel_size,
float random_u,
bool return_mask,
DenseTensor* dx);
template <typename T, typename Context>
void FractionalMaxPool3dGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& mask,
const DenseTensor& dout,