Skip to content

Commit 49c1768

Browse files
committed
Fix propagating information to projections inference in pooling operations
1 parent 8048ef5 commit 49c1768

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

lib/nn_blocks.ml

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,16 +195,26 @@ let%op depthwise_separable_conv2d ~label ?(kernel_size = 3) ?(stride = 1) () x =
195195

196196
(** Max pooling for 2D spatial data - reduces spatial dimensions by taking maximum values. *)
197197
let%op max_pool2d ?(stride = 2) ?(window_size = 2) () x =
198-
(* Although there is no kernel participating, we need to set the iteration size. *)
199198
Shape.set_dim wh window_size;
200199
Shape.set_dim ww window_size;
201-
x @^^ "... | stride*oh+wh, stride*ow+ww, ..c.. => ... | oh, ow, ..c.." [ "wh"; "ww" ]
200+
(* NOTE: projections inference runs per-assignment in a distinct phase from shape inference, so
201+
for it to know about the window size, we use a constant kernel = 1 to propagate the shape.
202+
We use a trick to create a shape-inferred constant tensor, equivalently we could write
203+
"NTDSL.term ~fetch_op:(Constant 1.) ()" but that's less concise. See:
204+
https://github.com/ahrefs/ocannl/discussions/381 *)
205+
x
206+
@^+ "... | stride*oh+wh, stride*ow+ww, ..c..; wh, ww => ... | oh, ow, ..c.." [ "wh"; "ww" ]
207+
(0.5 + 0.5)
202208

203209
(** Average pooling for 2D spatial data - reduces spatial dimensions by averaging values. *)
204210
let%op avg_pool2d ?(stride = 2) ?(window_size = 2) () x =
205211
Shape.set_dim wh window_size;
206212
Shape.set_dim ww window_size;
207-
let sum = x ++ "... | stride*oh+wh, stride*ow+ww, ..c.. => ... | oh, ow, ..c.." [ "wh"; "ww" ] in
213+
let sum =
214+
x
215+
+++ "... | stride*oh+wh, stride*ow+ww, ..c..; wh, ww => ... | oh, ow, ..c.." [ "wh"; "ww" ]
216+
(0.5 + 0.5)
217+
in
208218
sum /. (dim wh *. dim ww)
209219

210220
(** Global average pooling - reduces each feature map to a single value by averaging. Commonly used

0 commit comments

Comments
 (0)